change / sglang / Commits

Commit 3d8f1c9b (unverified)
Use int64 as indices for set_kv_buffer (#3039)

Authored Jan 21, 2025 by Lianmin Zheng; committed by GitHub on Jan 21, 2025.
Parent: a42213db

Showing 6 changed files, with 30 additions and 37 deletions (+30, -37):
    python/sglang/bench_one_batch.py                         +3   -5
    python/sglang/srt/layers/logits_processor.py             +1   -1
    python/sglang/srt/layers/sampler.py                      +2   -5
    python/sglang/srt/managers/schedule_batch.py             +7   -7
    python/sglang/srt/model_executor/cuda_graph_runner.py    +15  -17
    python/sglang/srt/model_executor/forward_batch_info.py   +2   -2
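Across these files the pattern is consistent: buffers holding token-level KV-cache locations widen to int64, while per-request indices settle on int32. For context only (not part of the commit), this is the wraparound that int32 indices risk once a flat token pool passes 2**31 - 1 slots; a minimal, self-contained demonstration:

    import torch

    # Illustrative slot id just past the int32 range.
    loc = torch.tensor([2**31 + 1])  # torch defaults to int64
    assert loc.dtype == torch.int64

    # Downcasting silently wraps to a negative value; used as an index
    # into a KV pool, it would scatter the entry to the wrong slot.
    wrapped = loc.to(torch.int32)
    print(wrapped.item())  # -2147483647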
python/sglang/bench_one_batch.py (view file @ 3d8f1c9b)
@@ -99,10 +99,7 @@ class BenchArgs:
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
         parser.add_argument(
-            "--profile",
-            action="store_true",
-            help="Use Torch Profiler. The endpoint must be launched with "
-            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+            "--profile", action="store_true", help="Use Torch Profiler."
         )
         parser.add_argument(
             "--profile-filename-prefix",
@@ -381,6 +378,7 @@ def latency_test_run_once(
         parent_dir = os.path.dirname(os.path.abspath(profile_filename))
         os.makedirs(parent_dir, exist_ok=True)
         profiler.export_chrome_trace(profile_filename)
+        rank_print(f"torch profiler chrome trace saved to {profile_filename}")

     # Record decode timing from 2nd output
     if output_len > 1:
@@ -451,7 +449,7 @@ def latency_test(
             il,
             ol,
             server_args.device,
-            bench_args.profile,
+            bench_args.profile if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
         )
     if ret is not None:
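Taken together, these hunks shorten the --profile help text (dropping the endpoint note about SGLANG_TORCH_PROFILER_DIR), print where the chrome trace landed, and pass a profile path only on tp_rank 0 so a multi-GPU run writes a single trace. A sketch of the same rank-gated pattern, using the real torch.profiler API but an invented helper (run_once_with_profile is not sglang code):

    import os
    import torch
    from torch.profiler import ProfilerActivity, profile

    def run_once_with_profile(fn, profile_filename, tp_rank=0):
        """Illustrative helper: profile fn() and export a chrome trace."""
        # Gate on rank so only one process writes a trace (mirrors the
        # `bench_args.profile if tp_rank == 0 else None` change above).
        if tp_rank != 0 or profile_filename is None:
            return fn()
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
            result = fn()
        os.makedirs(os.path.dirname(os.path.abspath(profile_filename)), exist_ok=True)
        prof.export_chrome_trace(profile_filename)
        print(f"torch profiler chrome trace saved to {profile_filename}")
        return result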
python/sglang/srt/layers/logits_processor.py (view file @ 3d8f1c9b)
@@ -296,7 +296,7 @@ def fused_softcap_kernel(
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
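The single changed line matters because tl.program_id(0) is a 32-bit value, so block_start = pid * BLOCK_SIZE is computed in int32 and can wrap once the tensor exceeds 2**31 - 1 elements; casting the pid first makes every derived offset 64-bit. A minimal standalone Triton kernel using the same pattern (assumes a CUDA device; the sizes here are small, the cast only matters at scale):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Cast the program id to int64 so block_start/offsets cannot
        # wrap around on tensors with more than 2**31 - 1 elements.
        pid = tl.program_id(0).to(tl.int64)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(x_ptr + offsets, x * 2.0, mask=mask)

    x = torch.ones(4096, device="cuda")
    grid = (triton.cdiv(x.numel(), 1024),)
    scale_kernel[grid](x, x.numel(), BLOCK_SIZE=1024)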
python/sglang/srt/layers/sampler.py (view file @ 3d8f1c9b)
 import logging
-from typing import Dict, List
+from typing import List

 import torch
 from torch import nn

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
@@ -109,8 +108,6 @@ class Sampler(nn.Module):
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
-
-            batch_next_token_ids = batch_next_token_ids.to(torch.int32)

         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
             if any(x > 0 for x in top_logprobs_nums):
@@ -124,7 +121,7 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]

-        return batch_next_token_ids
+        return batch_next_token_ids.to(torch.int32)

     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
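Net effect in the sampler: the int32 downcast moves from the middle of forward to the single return site, so intermediate uses of the sampled ids (such as indexing logprobs out of the distribution) run at the full int64 width, and the unused Dict and CustomLogitProcessor imports drop out. A hedged sketch of the downcast-at-the-boundary pattern (an illustrative function, not the sglang Sampler):

    import torch

    def sample_and_gather(logits: torch.Tensor):
        """Illustrative only: keep ids int64 internally, downcast once on return."""
        probs = torch.softmax(logits, dim=-1)
        ids = torch.multinomial(probs, num_samples=1).squeeze(-1)  # int64
        # Index with the int64 ids while they are still wide enough.
        logprobs = torch.log(probs[torch.arange(probs.shape[0]), ids])
        # Downcast only at the boundary, after all indexing is done.
        return ids.to(torch.int32), logprobs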
python/sglang/srt/managers/schedule_batch.py (view file @ 3d8f1c9b)
@@ -550,13 +550,13 @@ class ScheduleBatch:
     next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    input_embeds: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
+    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None
-    output_ids: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None  # shape: [b], int32
+    output_ids: torch.Tensor = None  # shape: [b], int32

     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -1026,7 +1026,7 @@ class ScheduleBatch:
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
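The new shape comments plus the req_pool_indices dtype flip encode a sizing rule: a request-pool index counts concurrent requests (small, so int32 suffices), while seq_lens feeds token-level arithmetic whose sums and products can exceed 32 bits (so int64). Rough arithmetic, for illustration only:

    import torch

    INT32_MAX = 2**31 - 1  # 2,147,483,647

    # req_pool_indices: one slot per running request; even 100k concurrent
    # requests is far below INT32_MAX, so int32 is safe.
    req_pool_indices = torch.empty(0, dtype=torch.int32)

    # seq_lens: individual lengths are small, but derived quantities
    # (e.g. seq_lens_sum * hidden_size) can exceed INT32_MAX, so int64.
    seq_lens = torch.empty(0, dtype=torch.int64)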
python/sglang/srt/model_executor/cuda_graph_runner.py (view file @ 3d8f1c9b)
@@ -24,7 +24,7 @@ import tqdm
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.distributed.parallel_state import graph_capture
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -63,7 +63,7 @@ def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
     batch_size: int,
-    tp_group: "GroupCoordinator",
+    tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
@@ -149,9 +149,18 @@ class CudaGraphRunner:
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )

         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1

         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
                 self.num_tokens_per_bs = (
@@ -163,16 +172,6 @@ class CudaGraphRunner:
                     self.model_runner.server_args.speculative_num_draft_tokens
                 )

-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -180,7 +179,6 @@ class CudaGraphRunner:
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0

@@ -189,14 +187,14 @@ class CudaGraphRunner:

         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

             # Speculative_inference
             if model_runner.spec_algorithm.is_eagle():
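The dtype changes to these buffers follow from how CUDA graphs work: a captured graph replays fixed kernels over fixed addresses, so each input buffer is allocated once, at maximum size, and its dtype must already match what the captured kernels expect; if set_kv_buffer indexes with int64, the static out_cache_loc buffer must be int64 before capture. A minimal capture/replay sketch with a static int64 index buffer (plain PyTorch, not the sglang runner):

    import torch

    pool = torch.randn(1024, 8, device="cuda")               # stand-in for a KV pool
    loc = torch.zeros(4, dtype=torch.int64, device="cuda")   # static index buffer
    out = torch.empty(4, 8, device="cuda")

    # Warm up on a side stream, then capture a graph that gathers
    # rows of `pool` at the locations held in `loc`.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        out.copy_(pool[loc])
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out.copy_(pool[loc])

    # Replay: only the *contents* of the static buffers may change;
    # their addresses and dtypes are frozen into the graph.
    loc.copy_(torch.tensor([1, 3, 5, 7], device="cuda"))
    g.replay()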
python/sglang/srt/model_executor/forward_batch_info.py (view file @ 3d8f1c9b)
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import maybe_torch_compile
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -415,6 +415,6 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc


-@maybe_torch_compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
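Both changes swap the conditional maybe_torch_compile wrapper for a plain torch.compile whose backend comes from get_compiler_backend() at import time. A plausible shape for such a helper, shown only as an assumption (the real sglang.srt.utils implementation may choose differently):

    import torch

    def get_compiler_backend() -> str:
        # Hypothetical: pick a torch.compile backend for the current platform.
        # (sglang's real helper may consult hardware or environment instead.)
        if hasattr(torch, "hpu") and torch.hpu.is_available():
            return "hpu_backend"
        return "inductor"

    @torch.compile(dynamic=True, backend=get_compiler_backend())
    def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
        return torch.clamp(seq_lens - 1, min=0).to(torch.int64)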