Commit 60fdad7c (unverified) · sglang

Sync the changes on cuda graph runners (#6932)

Authored Jun 06, 2025 by Lianmin Zheng; committed by GitHub on Jun 06, 2025. Parent: 61ce91ed.

Showing 7 changed files with 63 additions and 70 deletions (+63 −70):
python/sglang/srt/managers/io_struct.py (+2 −2)
python/sglang/srt/model_executor/cuda_graph_runner.py (+21 −30)
python/sglang/srt/server_args.py (+13 −2)
python/sglang/srt/speculative/build_eagle_tree.py (+8 −8)
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py (+3 −8)
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py (+3 −9)
python/sglang/srt/speculative/eagle_utils.py (+13 −11)
python/sglang/srt/managers/io_struct.py

@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from sglang.srt.mm_utils import has_valid_data
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 else:
     Image = Any

-from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -259,23 +259,8 @@ class CudaGraphRunner:
         }

         # Speculative_inference
-        if (
-            model_runner.spec_algorithm.is_eagle3()
-            and not model_runner.is_draft_worker
-        ):
-            self.hidden_states = torch.zeros(
-                (
-                    self.max_num_token,
-                    3 * self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
-            self.model_runner.model.set_eagle3_layers_to_capture()
-        elif model_runner.spec_algorithm.is_eagle():
-            self.hidden_states = torch.zeros(
-                (self.max_num_token, self.model_runner.model_config.hidden_size),
-                dtype=self.model_runner.dtype,
-            )
+        if model_runner.spec_algorithm.is_eagle3():
+            self.model_runner.model.set_eagle3_layers_to_capture()

         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
@@ -284,6 +269,7 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None
+
         if self.enable_dp_attention or self.enable_sp_layernorm:
             # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
@@ -303,13 +289,7 @@ class CudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture CUDA graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )

     @contextmanager
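The inline troubleshooting text removed above reappears once at the bottom of this file as the module-level constant CUDA_GRAPH_CAPTURE_FAILED_MSG (see the final hunk of this file), which the two EAGLE runners in this commit also import. A minimal sketch of the resulting pattern; the wrapper function here is hypothetical, only the constant and the f-string come from this diff:

CUDA_GRAPH_CAPTURE_FAILED_MSG = (
    "Possible solutions:\n"
    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
    "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
)  # abbreviated here; the full text is in the last hunk of this file

def capture_with_hint(capture_fn):
    # Hypothetical helper: each runner wraps its capture() call the same way
    # instead of duplicating the troubleshooting message inline.
    try:
        capture_fn()
    except RuntimeError as e:
        raise Exception(
            f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
        )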
@@ -439,6 +419,7 @@ class CudaGraphRunner:
         self.capture_hidden_mode = (
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )
+
         if self.model_runner.server_args.lora_paths is not None:
             # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
             # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
@@ -467,9 +448,9 @@
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
-            lora_paths=lora_paths,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
+            lora_paths=lora_paths,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
@@ -497,7 +478,9 @@
             self.pp_size > 1
             and "pp_proxy_tensors" in inspect.signature(forward).parameters
         ):
-            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+            kwargs["pp_proxy_tensors"] = PPProxyTensors(
+                {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
+            )

         logits_output_or_pp_proxy_tensors = forward(
             input_ids,
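Cloning into a fresh PPProxyTensors at capture time follows the basic CUDA-graph contract: a replayed graph re-executes kernels on the exact buffers it captured, so captured inputs must live in graph-owned storage that is later refreshed in place, not in whatever tensor the caller happened to pass. A generic plain-PyTorch illustration of that contract (not sglang code):

import torch

# Illustrative only. Real code also warms the ops up on a side stream
# before capture, which is omitted here for brevity.
static_x = torch.zeros(8, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = static_x * 2      # the kernel captures static_x's address

def run(x: torch.Tensor) -> torch.Tensor:
    static_x.copy_(x)            # refresh the graph-owned input in place
    g.replay()                   # re-runs the captured kernel on the same buffers
    return static_y.clone()      # copy out before the next replay overwrites it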
@@ -590,9 +573,6 @@
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

-        if hasattr(forward_batch.spec_info, "hidden_states"):
-            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
-
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
@@ -650,7 +630,7 @@
         else:
             spec_info = EagleVerifyInput(
                 draft_token=None,
-                custom_mask=torch.zeros(
+                custom_mask=torch.ones(
                     (num_tokens * self.model_runner.model_config.context_len),
                     dtype=torch.bool,
                     device="cuda",
@@ -660,9 +640,20 @@
                 retrive_next_token=None,
                 retrive_next_sibling=None,
                 retrive_cum_len=None,
-                draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                spec_steps=self.model_runner.server_args.speculative_num_steps,
+                topk=self.model_runner.server_args.speculative_eagle_topk,
+                draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                 capture_hidden_mode=CaptureHiddenMode.FULL,
             )
         return spec_info
+
+
+CUDA_GRAPH_CAPTURE_FAILED_MSG = (
+    "Possible solutions:\n"
+    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+    "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
+    "3. disable torch compile by not using --enable-torch-compile\n"
+    "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
+    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+)
python/sglang/srt/server_args.py

@@ -447,7 +447,7 @@ class ServerArgs:
                 self.speculative_num_steps,
                 self.speculative_eagle_topk,
                 self.speculative_num_draft_tokens,
-            ) = auto_choose_speculative_params(model_arch)
+            ) = auto_choose_speculative_params(self)

         if self.page_size > 1 and self.speculative_eagle_topk > 1:
             self.speculative_eagle_topk = 1
@@ -1655,12 +1655,23 @@ def get_model_arch(args: ServerArgs):
     return hf_config.architectures[0]


-def auto_choose_speculative_params(arch: str):
+def auto_choose_speculative_params(self: ServerArgs):
     """
     Automatically choose the parameters for speculative decoding.

     You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
     """
+    kwargs = {}
+
+    hf_config = get_config(
+        self.model_path,
+        trust_remote_code=self.trust_remote_code,
+        revision=self.revision,
+        model_override_args=json.loads(self.json_model_override_args),
+        **kwargs,
+    )
+    arch = hf_config.architectures[0]
+
     if arch in ["LlamaForCausalLM"]:
         # The default value for llama
         return (5, 4, 8)
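auto_choose_speculative_params now receives the whole ServerArgs and resolves the model architecture itself through sglang's get_config, honoring trust_remote_code, revision, and the JSON override args, instead of trusting a pre-computed arch string. A stripped-down sketch of just the resolution step, using plain transformers rather than sglang's get_config (hypothetical helper name):

from transformers import AutoConfig

def resolve_arch(model_path: str, trust_remote_code: bool = False) -> str:
    # Simplified stand-in for the get_config(...) call in the diff above.
    hf_config = AutoConfig.from_pretrained(
        model_path, trust_remote_code=trust_remote_code
    )
    return hf_config.architectures[0]  # e.g. "LlamaForCausalLM"

For LlamaForCausalLM the visible branch then returns (5, 4, 8), i.e. speculative_num_steps=5, speculative_eagle_topk=4, speculative_num_draft_tokens=8.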
python/sglang/srt/speculative/build_eagle_tree.py

@@ -4,7 +4,7 @@ from typing import List
 import torch

-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, rank0_print

 if is_cuda() or is_hip():
     from sgl_kernel import (
@@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )

-    first_rank_print("=========== build tree kernel efficient ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_next_token=}", flush=True)
-    first_rank_print(f"{retrive_next_sibling=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
+    rank0_print("=========== build tree kernel efficient ==========")
+    # rank0_print(f"{tree_mask=}", flush=True)
+    rank0_print(f"{position=}", flush=True)
+    rank0_print(f"{retrive_index=}", flush=True)
+    rank0_print(f"{retrive_next_token=}", flush=True)
+    rank0_print(f"{retrive_next_sibling=}", flush=True)
+    rank0_print(f"{draft_tokens=}", flush=True)

     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
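first_rank_print is renamed to rank0_print, now imported from sglang.srt.utils. The diff does not show its definition; a typical implementation of such a helper looks like the following sketch (sglang's real rank0_print may differ in detail):

import torch.distributed as dist

def rank0_print(*args, **kwargs):
    # Print only on the first rank so multi-GPU test runs emit each
    # debug line exactly once.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)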
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable
 import torch

 from sglang.srt.model_executor.cuda_graph_runner import (
+    CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
@@ -73,7 +74,7 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
-            (self.max_bs, self.model_runner.model_config.hidden_size),
+            (self.max_num_token, self.model_runner.model_config.hidden_size),
             dtype=self.model_runner.dtype,
         )
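The static hidden-states buffer grows from one row per batch entry (max_bs) to one row per token (max_num_token). Assuming, as the replay code of these runners suggests, that each batch entry contributes num_tokens_per_bs drafted tokens, a max_bs-row buffer undersizes the capture whenever that factor exceeds one. A worked example with hypothetical numbers:

# All sizes hypothetical, for illustration only.
max_bs = 160                   # largest captured batch size
num_tokens_per_bs = 4          # drafted tokens per batch entry
hidden_size = 4096

max_num_token = max_bs * num_tokens_per_bs   # 640
# old: torch.zeros((max_bs, hidden_size))         -> 160 rows, too small
# new: torch.zeros((max_num_token, hidden_size))  -> 640 rows, one per token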
@@ -82,13 +83,7 @@ class EAGLEDraftCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture CUDA graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )

     def can_run(self, forward_batch: ForwardBatch):
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable
 import torch

 from sglang.srt.model_executor.cuda_graph_runner import (
+    CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
     LogitsProcessorOutput,
     get_batch_sizes_to_capture,
@@ -89,13 +90,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture CUDA graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            )

     def can_run(self, forward_batch: ForwardBatch):
@@ -200,7 +195,6 @@ class EAGLEDraftExtendCudaGraphRunner:
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]
-        assert raw_bs * self.num_tokens_per_bs == num_tokens

         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
@@ -224,9 +218,9 @@ class EAGLEDraftExtendCudaGraphRunner:
             self.seq_lens_cpu.fill_(1)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

-        forward_batch.spec_info.positions = None
         if bs != raw_bs:
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
+            forward_batch.spec_info.positions = None

         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
             bs=bs,
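For context, the replay path pads a raw batch up to the nearest captured size: capture_bs is sorted, bisect_left picks the smallest captured graph that fits, and spec_info is only padded when bs != raw_bs (which is why positions is now reset inside that branch). A tiny worked example of the lookup:

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]   # hypothetical captured batch sizes
raw_bs = 5                          # incoming batch
bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
assert bs == 8                      # replay the bs=8 graph, padding 3 entries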
python/sglang/srt/speculative/eagle_utils.py

@@ -232,8 +232,9 @@ class EagleVerifyInput:
     retrive_next_token: torch.Tensor
     retrive_next_sibling: torch.Tensor
     retrive_cum_len: torch.Tensor
-    draft_token_num: int
     spec_steps: int
+    topk: int
+    draft_token_num: int
     capture_hidden_mode: CaptureHiddenMode
     grammar: BaseGrammarObject = None
@@ -270,16 +271,17 @@ class EagleVerifyInput:
         )
         return cls(
-            draft_tokens,
-            tree_mask,
-            position,
-            retrive_index,
-            retrive_next_token,
-            retrive_next_sibling,
-            None,
-            num_verify_tokens,
-            spec_steps,
-            CaptureHiddenMode.FULL,
+            draft_token=draft_tokens,
+            custom_mask=tree_mask,
+            positions=position,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            retrive_cum_len=None,
+            spec_steps=spec_steps,
+            topk=topk,
+            draft_token_num=num_verify_tokens,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
         )

     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
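Switching the cls(...) call to keyword arguments is what makes the field reordering in the first hunk safe: positional construction of a dataclass silently rebinds values when fields move. A self-contained illustration with a toy dataclass (not the real EagleVerifyInput):

from dataclasses import dataclass

@dataclass
class VerifyInput:        # toy stand-in for EagleVerifyInput
    spec_steps: int
    topk: int             # field inserted by this commit
    draft_token_num: int

# A positional call written against the old field order would now bind each
# value to the wrong field without raising. Keyword construction is immune:
v = VerifyInput(spec_steps=5, topk=4, draft_token_num=8)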