sglang commit 2fc12995 (unverified)
Authored Jun 08, 2025 by fzyzcjy; committed by GitHub on Jun 08, 2025
Parent: 20d3ad3b

Remove unnecessary kernels of num_token_non_padded (#6965)

Showing 3 changed files with 33 additions and 27 deletions:

python/sglang/srt/model_executor/cuda_graph_runner.py   (+14 / -10)
python/sglang/srt/model_executor/forward_batch_info.py  (+19 / -14)
python/sglang/srt/two_batch_overlap.py                  (+0 / -3)
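The change is summarized by the helper it introduces: num_token_non_padded is only consumed when EP MoE or DeepEP MoE padding is in use, so the unconditional torch.tensor(...).to(device, non_blocking=True) launched for every batch (and the matching copy on every CUDA graph replay) is wasted work otherwise. Below is a minimal sketch of the gating pattern the diff introduces; SimpleNamespace stands in for sglang's ServerArgs and make_num_token_non_padded is a hypothetical wrapper for illustration only.

```python
# Minimal sketch of the gating pattern this commit introduces. SimpleNamespace
# stands in for sglang's ServerArgs; make_num_token_non_padded is a
# hypothetical wrapper used only for illustration.
from types import SimpleNamespace

import torch


def enable_num_token_non_padded(server_args):
    # Mirrors the helper added to forward_batch_info.py: the tensor is only
    # consumed when EP MoE or DeepEP MoE padding is in use.
    return server_args.enable_ep_moe or server_args.enable_deepep_moe


def make_num_token_non_padded(input_ids, server_args, device="cpu"):
    # Before: always build the scalar tensor and issue the async copy.
    # After: skip both kernels entirely when nothing will read the tensor.
    if not enable_num_token_non_padded(server_args):
        return None
    return torch.tensor(len(input_ids), dtype=torch.int32).to(
        device, non_blocking=True
    )


args = SimpleNamespace(enable_ep_moe=False, enable_deepep_moe=False)
print(make_num_token_non_padded(list(range(8)), args))  # None -> no kernels launched
```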
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -35,6 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
     PPProxyTensors,
+    enable_num_token_non_padded,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
@@ -190,6 +191,9 @@ class CudaGraphRunner:
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
         self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.enable_two_batch_overlap = (
+            model_runner.server_args.enable_two_batch_overlap
+        )
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -327,9 +331,7 @@ class CudaGraphRunner:
         )
         is_tbo_supported = (
-            forward_batch.can_run_tbo
-            if self.model_runner.server_args.enable_two_batch_overlap
-            else True
+            forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
         )
         return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
@@ -549,13 +551,7 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        num_token_non_padded = len(forward_batch.input_ids)
-        self.num_token_non_padded[...] = num_token_non_padded
-        self.tbo_plugin.replay_prepare(
-            forward_mode=forward_batch.forward_mode,
-            bs=bs,
-            num_token_non_padded=num_token_non_padded,
-        )
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
@@ -572,6 +568,14 @@ class CudaGraphRunner:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+        if enable_num_token_non_padded(self.model_runner.server_args):
+            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
+        if self.enable_two_batch_overlap:
+            self.tbo_plugin.replay_prepare(
+                forward_mode=forward_batch.forward_mode,
+                bs=bs,
+                num_token_non_padded=len(forward_batch.input_ids),
+            )

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
```
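The last hunk above is where the replay path sheds its unconditional work: the num_token_non_padded write and the TBO replay_prepare call are now gated on flags the runner caches at init. A simplified sketch of that replay-side structure follows; GatedReplay is a stand-in for illustration, not the real CudaGraphRunner.

```python
# Simplified sketch of the gated replay logic. GatedReplay is a stand-in for
# illustration, not the real CudaGraphRunner; enable_num_token_non_padded
# mirrors the helper added to forward_batch_info.py.
from types import SimpleNamespace

import torch


def enable_num_token_non_padded(server_args):
    return server_args.enable_ep_moe or server_args.enable_deepep_moe


class GatedReplay:
    def __init__(self, server_args, enable_two_batch_overlap, tbo_plugin):
        self.server_args = server_args
        # Cached once at init, as the diff does with self.enable_two_batch_overlap.
        self.enable_two_batch_overlap = enable_two_batch_overlap
        self.tbo_plugin = tbo_plugin
        self.num_token_non_padded = torch.zeros((), dtype=torch.int32)

    def replay_prepare(self, forward_batch, bs):
        # Copy the non-padded token count only when a MoE backend will read it.
        if enable_num_token_non_padded(self.server_args):
            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
        # Call the two-batch-overlap plugin only when TBO is enabled, so the
        # plugin itself no longer needs a global-args guard.
        if self.enable_two_batch_overlap:
            self.tbo_plugin.replay_prepare(
                forward_mode=forward_batch.forward_mode,
                bs=bs,
                num_token_non_padded=len(forward_batch.input_ids),
            )


args = SimpleNamespace(enable_ep_moe=False, enable_deepep_moe=False)
runner = GatedReplay(args, enable_two_batch_overlap=False, tbo_plugin=None)
fb = SimpleNamespace(
    num_token_non_padded=torch.tensor(5, dtype=torch.int32),
    forward_mode=None,
    input_ids=torch.arange(5),
)
runner.replay_prepare(fb, bs=1)  # both gates off: no copy, no plugin call
```

Because the caller now checks enable_two_batch_overlap itself, the plugin-side guard on global_server_args_dict["enable_two_batch_overlap"] (removed in two_batch_overlap.py below) becomes redundant.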
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -118,6 +118,7 @@ class ForwardMode(IntEnum):
 class CaptureHiddenMode(IntEnum):
+    # Do not capture anything.
     NULL = auto()
     # Capture hidden states of all tokens.
     FULL = auto()
@@ -253,6 +254,7 @@ class ForwardBatch:
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None

+    # For two-batch overlap
     tbo_split_seq_index: Optional[int] = None
     tbo_parent_token_range: Optional[Tuple[int, int]] = None
     tbo_children: Optional[List["ForwardBatch"]] = None
@@ -265,12 +267,6 @@ class ForwardBatch:
     ):
         from sglang.srt.two_batch_overlap import TboForwardBatchPreparer

-        device = model_runner.device
-        extend_input_logprob_token_ids_gpu = None
-        if batch.extend_input_logprob_token_ids is not None:
-            extend_input_logprob_token_ids_gpu = (
-                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
-            )
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -284,6 +280,7 @@ class ForwardBatch:
             encoder_lens_cpu=batch.encoder_lens_cpu,
             encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
+            seq_lens_cpu=batch.seq_lens_cpu,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
@@ -298,12 +295,19 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
-            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
-            num_token_non_padded=torch.tensor(
-                len(batch.input_ids), dtype=torch.int32
-            ).to(device, non_blocking=True),
             tbo_split_seq_index=batch.tbo_split_seq_index,
         )

+        device = model_runner.device
+        if batch.extend_input_logprob_token_ids is not None:
+            ret.extend_input_logprob_token_ids_gpu = (
+                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
+            )
+        if enable_num_token_non_padded(model_runner.server_args):
+            ret.num_token_non_padded = torch.tensor(
+                len(batch.input_ids), dtype=torch.int32
+            ).to(device, non_blocking=True)

         # For DP attention
         if batch.global_num_tokens is not None:
@@ -323,6 +327,7 @@ class ForwardBatch:
                 dtype=model_runner.dtype,
                 device=device,
             )

         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
             TboForwardBatchPreparer.prepare(ret)
@@ -335,10 +340,6 @@ class ForwardBatch:
         ):
             ret.positions = ret.spec_info.positions
-        # Get seq_lens_cpu if needed
-        if ret.seq_lens_cpu is None:
-            ret.seq_lens_cpu = batch.seq_lens_cpu

         # Init position information
         if ret.forward_mode.is_decode():
             if ret.positions is None:
@@ -605,6 +606,10 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None


+def enable_num_token_non_padded(server_args):
+    return server_args.enable_ep_moe or server_args.enable_deepep_moe


 class PPProxyTensors:
     # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
     tensors: Dict[str, torch.Tensor]
```
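ForwardBatch.init_new now fills in the optional GPU-side fields after constructing the object, and only when a feature actually needs them, with the new module-level enable_num_token_non_padded helper as the gate. Here is a small self-contained sketch of that post-construction pattern; MiniForwardBatch and SimpleNamespace are stand-ins for the real ForwardBatch and ServerArgs.

```python
# Sketch of the post-construction pattern now used in ForwardBatch.init_new:
# optional GPU-side fields default to None and are filled in only when the
# corresponding feature is enabled. MiniForwardBatch and SimpleNamespace are
# stand-ins for the real ForwardBatch and ServerArgs.
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional

import torch


def enable_num_token_non_padded(server_args):
    return server_args.enable_ep_moe or server_args.enable_deepep_moe


@dataclass
class MiniForwardBatch:
    input_ids: torch.Tensor
    num_token_non_padded: Optional[torch.Tensor] = None

    @classmethod
    def init_new(cls, input_ids, server_args, device="cpu"):
        ret = cls(input_ids=input_ids)
        # The tensor creation and async copy happen only when an EP/DeepEP
        # MoE backend is configured to consume them.
        if enable_num_token_non_padded(server_args):
            ret.num_token_non_padded = torch.tensor(
                len(input_ids), dtype=torch.int32
            ).to(device, non_blocking=True)
        return ret


args = SimpleNamespace(enable_ep_moe=True, enable_deepep_moe=False)
batch = MiniForwardBatch.init_new(torch.arange(4), args)
print(batch.num_token_non_padded)  # tensor(4, dtype=torch.int32)
```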
python/sglang/srt/two_batch_overlap.py

```diff
@@ -131,9 +131,6 @@ class TboCudaGraphRunnerPlugin:
     def replay_prepare(
         self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
     ):
-        if not global_server_args_dict["enable_two_batch_overlap"]:
-            return
-
         tbo_split_seq_index, tbo_split_token_index = (
             compute_split_indices_for_cuda_graph_replay(
                 forward_mode=forward_mode,
```