change / sglang / Commits / ac5010e0

Commit ac5010e0 (Unverified)
Authored Jun 22, 2025 by Cheng Wan; committed by GitHub on Jun 22, 2025
Parent: 3cee035e

Fix CUDA Graph Check under Deepep with DP FFN (#7451)
Showing 3 changed files with 40 additions and 40 deletions (+40 −40):

  python/sglang/srt/model_executor/cuda_graph_runner.py                    +13 −11
  python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py           +14 −13
  python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py    +13 −16
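As the diffs below show, all three CUDA graph runners shared the same bug: the data-parallel consistency check (forward_batch.can_run_dp_cuda_graph) lived only inside the require_mlp_tp_gather branch of can_run, so it was skipped entirely when DeepEP runs with a data-parallel FFN but without MLP TP gather. The fix decouples the two concerns in each runner: compute a single cuda_graph_bs in both branches, apply the padding/batch-size test once, and then gate on can_run_dp_cuda_graph whenever the newly imported require_mlp_sync(server_args) predicate holds.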
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -48,6 +48,7 @@ from sglang.srt.utils import (
     rank0_log,
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
...
@@ -212,6 +213,7 @@ class CudaGraphRunner:
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
...
@@ -337,22 +339,22 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
         else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
...
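To see the new decision logic in isolation, here is a minimal, self-contained sketch of what can_run computes after this change. DummyBatch and DummyRunner are hypothetical stand-ins invented for this example; only the attribute names mirror the diff above.

    # Sketch of the post-fix can_run() decision. DummyBatch/DummyRunner are
    # hypothetical stand-ins; the real ForwardBatch and CudaGraphRunner carry
    # far more state.
    from dataclasses import dataclass, field
    from typing import Dict, List, Optional


    @dataclass
    class DummyBatch:
        batch_size: int
        can_run_dp_cuda_graph: bool = True
        global_num_tokens_cpu: Optional[List[int]] = None  # one entry per DP rank


    @dataclass
    class DummyRunner:
        graphs: Dict[int, object] = field(default_factory=dict)  # captured batch sizes
        max_bs: int = 160
        disable_padding: bool = False
        require_mlp_tp_gather: bool = False
        require_mlp_sync: bool = False
        num_tokens_per_bs: int = 1
        is_eagle: bool = False

        def can_run(self, fb: DummyBatch) -> bool:
            # Step 1: pick the batch size the captured graph must cover.
            if self.require_mlp_tp_gather:
                total = sum(fb.global_num_tokens_cpu)
                cuda_graph_bs = total // self.num_tokens_per_bs if self.is_eagle else total
            else:
                cuda_graph_bs = fb.batch_size
            # Step 2: padding decides exact-match vs. upper-bound test.
            if self.disable_padding:
                ok = cuda_graph_bs in self.graphs
            else:
                ok = cuda_graph_bs <= self.max_bs
            # Step 3: the DP sync gate now applies regardless of step 1's branch;
            # this is the actual fix for DeepEP with a data-parallel FFN.
            if self.require_mlp_sync:
                ok = ok and fb.can_run_dp_cuda_graph
            return ok


    # DeepEP + DP FFN case: MLP sync is required, MLP TP gather is not.
    runner = DummyRunner(require_mlp_sync=True)
    desynced = DummyBatch(batch_size=8, can_run_dp_cuda_graph=False)
    assert not runner.can_run(desynced)  # pre-fix logic would have returned True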
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -23,6 +23,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
...
@@ -46,6 +47,7 @@ class EAGLEDraftCudaGraphRunner:
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
...
@@ -127,24 +129,23 @@ class EAGLEDraftCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            if not forward_batch.can_run_dp_cuda_graph:
-                return False
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
         else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
         return is_bs_supported

     def capture(self):
...
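One detail shared by the require_mlp_tp_gather branch in all three runners: global_num_tokens_cpu holds per-DP-rank token counts, and under EAGLE the sum is divided by num_tokens_per_bs, presumably to convert a token count back into a batch size comparable with the captured graph sizes. A small arithmetic sketch with made-up values:

    # Made-up values; the real counts come from ForwardBatch and the runner.
    global_num_tokens_cpu = [12, 8]  # tokens reported by two DP ranks
    num_tokens_per_bs = 4            # tokens each batch element occupies under EAGLE
    cuda_graph_bs = sum(global_num_tokens_cpu) // num_tokens_per_bs
    assert cuda_graph_bs == 5        # 20 tokens -> batch size 5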
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
...
@@ -42,6 +43,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
...
@@ -130,29 +132,24 @@ class EAGLEDraftExtendCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            if not forward_batch.can_run_dp_cuda_graph:
-                return False
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
-            return is_bs_supported
         else:
-            batch_size = forward_batch.seq_lens.numel()
-            is_bs_supported = (
-                batch_size in self.graphs
-                if self.disable_padding
-                else batch_size <= self.max_bs
-            )
-            return is_bs_supported
+            cuda_graph_bs = forward_batch.seq_lens.numel()
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
+        return is_bs_supported

     def capture(self):
         CudaGraphRunner.capture(self)
...
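Note that the draft-extend runner derives its batch size from forward_batch.seq_lens.numel() rather than forward_batch.batch_size: seq_lens holds one entry per request, so its element count is the number of requests. A tiny illustration with invented values:

    import torch

    seq_lens = torch.tensor([17, 3, 42, 9])  # one sequence length per request
    assert seq_lens.numel() == 4             # batch size used by the graph check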