Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2fc12995
Unverified
Commit
2fc12995
authored
Jun 08, 2025
by
fzyzcjy
Committed by
GitHub
Jun 08, 2025
Browse files
Remove unnecessary kernels of num_token_non_padded (#6965)
parent
20d3ad3b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
27 deletions
+33
-27
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+14
-10
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+19
-14
python/sglang/srt/two_batch_overlap.py
python/sglang/srt/two_batch_overlap.py
+0
-3
No files found.
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
2fc12995
...
...
@@ -35,6 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch
,
ForwardMode
,
PPProxyTensors
,
enable_num_token_non_padded
,
)
from
sglang.srt.patch_torch
import
monkey_patch_torch_compile
from
sglang.srt.two_batch_overlap
import
TboCudaGraphRunnerPlugin
...
...
@@ -190,6 +191,9 @@ class CudaGraphRunner:
self
.
is_encoder_decoder
=
model_runner
.
model_config
.
is_encoder_decoder
self
.
enable_dp_attention
=
model_runner
.
server_args
.
enable_dp_attention
self
.
enable_sp_layernorm
=
model_runner
.
server_args
.
enable_sp_layernorm
self
.
enable_two_batch_overlap
=
(
model_runner
.
server_args
.
enable_two_batch_overlap
)
self
.
speculative_algorithm
=
model_runner
.
server_args
.
speculative_algorithm
self
.
tp_size
=
model_runner
.
server_args
.
tp_size
self
.
dp_size
=
model_runner
.
server_args
.
dp_size
...
...
@@ -327,9 +331,7 @@ class CudaGraphRunner:
)
is_tbo_supported
=
(
forward_batch
.
can_run_tbo
if
self
.
model_runner
.
server_args
.
enable_two_batch_overlap
else
True
forward_batch
.
can_run_tbo
if
self
.
enable_two_batch_overlap
else
True
)
return
is_bs_supported
and
is_encoder_lens_supported
and
is_tbo_supported
...
...
@@ -549,13 +551,7 @@ class CudaGraphRunner:
self
.
seq_lens
[:
raw_bs
].
copy_
(
forward_batch
.
seq_lens
)
self
.
out_cache_loc
[:
raw_num_token
].
copy_
(
forward_batch
.
out_cache_loc
)
self
.
positions
[:
raw_num_token
].
copy_
(
forward_batch
.
positions
)
num_token_non_padded
=
len
(
forward_batch
.
input_ids
)
self
.
num_token_non_padded
[...]
=
num_token_non_padded
self
.
tbo_plugin
.
replay_prepare
(
forward_mode
=
forward_batch
.
forward_mode
,
bs
=
bs
,
num_token_non_padded
=
num_token_non_padded
,
)
if
forward_batch
.
seq_lens_cpu
is
not
None
:
if
bs
!=
raw_bs
:
self
.
seq_lens_cpu
.
fill_
(
1
)
...
...
@@ -572,6 +568,14 @@ class CudaGraphRunner:
self
.
mrope_positions
[:,
:
raw_bs
].
copy_
(
forward_batch
.
mrope_positions
)
if
self
.
enable_dp_attention
or
self
.
enable_sp_layernorm
:
self
.
global_num_tokens_gpu
.
copy_
(
forward_batch
.
global_num_tokens_gpu
)
if
enable_num_token_non_padded
(
self
.
model_runner
.
server_args
):
self
.
num_token_non_padded
.
copy_
(
forward_batch
.
num_token_non_padded
)
if
self
.
enable_two_batch_overlap
:
self
.
tbo_plugin
.
replay_prepare
(
forward_mode
=
forward_batch
.
forward_mode
,
bs
=
bs
,
num_token_non_padded
=
len
(
forward_batch
.
input_ids
),
)
# Attention backend
self
.
model_runner
.
attn_backend
.
init_forward_metadata_replay_cuda_graph
(
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
2fc12995
...
...
@@ -118,6 +118,7 @@ class ForwardMode(IntEnum):
class
CaptureHiddenMode
(
IntEnum
):
# Do not capture anything.
NULL
=
auto
()
# Capture hidden states of all tokens.
FULL
=
auto
()
...
...
@@ -253,6 +254,7 @@ class ForwardBatch:
# For Qwen2-VL
mrope_positions
:
torch
.
Tensor
=
None
# For two-batch overlap
tbo_split_seq_index
:
Optional
[
int
]
=
None
tbo_parent_token_range
:
Optional
[
Tuple
[
int
,
int
]]
=
None
tbo_children
:
Optional
[
List
[
"ForwardBatch"
]]
=
None
...
...
@@ -265,12 +267,6 @@ class ForwardBatch:
):
from
sglang.srt.two_batch_overlap
import
TboForwardBatchPreparer
device
=
model_runner
.
device
extend_input_logprob_token_ids_gpu
=
None
if
batch
.
extend_input_logprob_token_ids
is
not
None
:
extend_input_logprob_token_ids_gpu
=
(
batch
.
extend_input_logprob_token_ids
.
to
(
device
,
non_blocking
=
True
)
)
ret
=
cls
(
forward_mode
=
batch
.
forward_mode
,
batch_size
=
len
(
batch
.
seq_lens
),
...
...
@@ -284,6 +280,7 @@ class ForwardBatch:
encoder_lens_cpu
=
batch
.
encoder_lens_cpu
,
encoder_out_cache_loc
=
batch
.
encoder_out_cache_loc
,
seq_lens_sum
=
batch
.
seq_lens_sum
,
seq_lens_cpu
=
batch
.
seq_lens_cpu
,
return_logprob
=
batch
.
return_logprob
,
top_logprobs_nums
=
batch
.
top_logprobs_nums
,
token_ids_logprobs
=
batch
.
token_ids_logprobs
,
...
...
@@ -298,12 +295,19 @@ class ForwardBatch:
spec_info
=
batch
.
spec_info
,
capture_hidden_mode
=
batch
.
capture_hidden_mode
,
input_embeds
=
batch
.
input_embeds
,
extend_input_logprob_token_ids_gpu
=
extend_input_logprob_token_ids_gpu
,
num_token_non_padded
=
torch
.
tensor
(
len
(
batch
.
input_ids
),
dtype
=
torch
.
int32
).
to
(
device
,
non_blocking
=
True
),
tbo_split_seq_index
=
batch
.
tbo_split_seq_index
,
)
device
=
model_runner
.
device
if
batch
.
extend_input_logprob_token_ids
is
not
None
:
ret
.
extend_input_logprob_token_ids_gpu
=
(
batch
.
extend_input_logprob_token_ids
.
to
(
device
,
non_blocking
=
True
)
)
if
enable_num_token_non_padded
(
model_runner
.
server_args
):
ret
.
num_token_non_padded
=
torch
.
tensor
(
len
(
batch
.
input_ids
),
dtype
=
torch
.
int32
).
to
(
device
,
non_blocking
=
True
)
# For DP attention
if
batch
.
global_num_tokens
is
not
None
:
...
...
@@ -323,6 +327,7 @@ class ForwardBatch:
dtype
=
model_runner
.
dtype
,
device
=
device
,
)
if
ret
.
forward_mode
.
is_idle
():
ret
.
positions
=
torch
.
empty
((
0
,),
device
=
device
)
TboForwardBatchPreparer
.
prepare
(
ret
)
...
...
@@ -335,10 +340,6 @@ class ForwardBatch:
):
ret
.
positions
=
ret
.
spec_info
.
positions
# Get seq_lens_cpu if needed
if
ret
.
seq_lens_cpu
is
None
:
ret
.
seq_lens_cpu
=
batch
.
seq_lens_cpu
# Init position information
if
ret
.
forward_mode
.
is_decode
():
if
ret
.
positions
is
None
:
...
...
@@ -605,6 +606,10 @@ class ForwardBatch:
return
self
.
tbo_split_seq_index
is
not
None
def enable_num_token_non_padded(server_args):
    """Return whether the num_token_non_padded tensor is needed.

    It is required only when expert-parallel MoE or DeepEP MoE is
    enabled in *server_args*; otherwise the kernels that consume it
    can be skipped entirely.
    """
    uses_ep_moe = server_args.enable_ep_moe
    uses_deepep_moe = server_args.enable_deepep_moe
    return uses_ep_moe or uses_deepep_moe
class
PPProxyTensors
:
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
tensors
:
Dict
[
str
,
torch
.
Tensor
]
...
...
python/sglang/srt/two_batch_overlap.py
View file @
2fc12995
...
...
@@ -131,9 +131,6 @@ class TboCudaGraphRunnerPlugin:
def
replay_prepare
(
self
,
forward_mode
:
ForwardMode
,
bs
:
int
,
num_token_non_padded
:
int
):
if
not
global_server_args_dict
[
"enable_two_batch_overlap"
]:
return
tbo_split_seq_index
,
tbo_split_token_index
=
(
compute_split_indices_for_cuda_graph_replay
(
forward_mode
=
forward_mode
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment