change / sglang / Commits / 60fdad7c

Unverified commit 60fdad7c, authored Jun 06, 2025 by Lianmin Zheng; committed by GitHub, Jun 06, 2025.

Sync the changes on cuda graph runners (#6932)

Parent: 61ce91ed

Showing 7 changed files with 63 additions and 70 deletions (+63, -70):
python/sglang/srt/managers/io_struct.py                                  +2   -2
python/sglang/srt/model_executor/cuda_graph_runner.py                    +21  -30
python/sglang/srt/server_args.py                                         +13  -2
python/sglang/srt/speculative/build_eagle_tree.py                        +8   -8
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py           +3   -8
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py    +3   -9
python/sglang/srt/speculative/eagle_utils.py                             +13  -11
python/sglang/srt/managers/io_struct.py

```diff
@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from sglang.srt.mm_utils import has_valid_data
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 else:
     Image = Any
 
-from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
```
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -259,23 +259,8 @@ class CudaGraphRunner:
         }
 
         # Speculative_inference
-        if (
-            model_runner.spec_algorithm.is_eagle3()
-            and not model_runner.is_draft_worker
-        ):
-            self.hidden_states = torch.zeros(
-                (
-                    self.max_num_token,
-                    3 * self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
-        elif model_runner.spec_algorithm.is_eagle():
-            self.hidden_states = torch.zeros(
-                (self.max_num_token, self.model_runner.model_config.hidden_size),
-                dtype=self.model_runner.dtype,
-            )
+        if model_runner.spec_algorithm.is_eagle3():
+            self.model_runner.model.set_eagle3_layers_to_capture()
 
         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
```
```diff
@@ -284,6 +269,7 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None
 
         if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
```
```diff
@@ -303,13 +289,7 @@ class CudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture CUDA graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )
 
     @contextmanager
```
```diff
@@ -439,6 +419,7 @@ class CudaGraphRunner:
         self.capture_hidden_mode = (
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )
+
         if self.model_runner.server_args.lora_paths is not None:
             # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
             # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
```
```diff
@@ -467,9 +448,9 @@ class CudaGraphRunner:
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
-            lora_paths=lora_paths,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
+            lora_paths=lora_paths,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
```
```diff
@@ -497,7 +478,9 @@ class CudaGraphRunner:
             self.pp_size > 1
             and "pp_proxy_tensors" in inspect.signature(forward).parameters
         ):
-            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+            kwargs["pp_proxy_tensors"] = PPProxyTensors(
+                {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
+            )
 
         logits_output_or_pp_proxy_tensors = forward(
             input_ids,
```
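Why clone the proxy tensors during capture? A captured CUDA graph replays fixed device addresses, so tensors that take part in capture should be buffers the runner owns, not caller-provided tensors whose storage may later be reused. That is a plausible reading of this change, not something the commit states. Below is a minimal, self-contained sketch of the general static-buffer pattern; `capture_with_static_input` is a hypothetical helper, not an sglang API.

```python
# Sketch of the static-buffer pattern behind CUDA graph capture.
import torch

def capture_with_static_input(fn, example_input: torch.Tensor):
    static_in = example_input.clone()  # graph-owned buffer with a stable address
    # Warm up on a side stream before capture, as PyTorch recommends.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out = fn(static_in)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = fn(static_in)  # capture reads/writes only static buffers

    def replay(x: torch.Tensor) -> torch.Tensor:
        static_in.copy_(x)  # copy new data into the captured address
        g.replay()
        return static_out

    return replay

if torch.cuda.is_available():
    doubled = capture_with_static_input(lambda t: t * 2, torch.ones(4, device="cuda"))
    print(doubled(torch.arange(4.0, device="cuda")))  # tensor([0., 2., 4., 6.])
```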
```diff
@@ -590,9 +573,6 @@ class CudaGraphRunner:
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
 
-        if hasattr(forward_batch.spec_info, "hidden_states"):
-            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
-
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
```
```diff
@@ -650,7 +630,7 @@ class CudaGraphRunner:
         else:
             spec_info = EagleVerifyInput(
                 draft_token=None,
-                custom_mask=torch.zeros(
+                custom_mask=torch.ones(
                     (num_tokens * self.model_runner.model_config.context_len),
                     dtype=torch.bool,
                     device="cuda",
```
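The dummy `EagleVerifyInput` built for capture now fills `custom_mask` with ones instead of zeros. Assuming `True` marks positions that may be attended (an assumption; the diff does not define the mask semantics), an all-`False` mask would leave every query with nothing to attend to, which a toy masked softmax shows degenerating to NaN:

```python
# Toy masked softmax: all-True vs. all-False boolean attention masks.
import torch

scores = torch.tensor([1.0, 2.0, 3.0])

def masked_softmax(s: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Positions where mask is False are excluded from attention.
    return torch.softmax(s.masked_fill(~mask, float("-inf")), dim=-1)

print(masked_softmax(scores, torch.ones(3, dtype=torch.bool)))   # a valid distribution
print(masked_softmax(scores, torch.zeros(3, dtype=torch.bool)))  # tensor([nan, nan, nan])
```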
```diff
@@ -660,9 +640,20 @@ class CudaGraphRunner:
                 retrive_next_token=None,
                 retrive_next_sibling=None,
                 retrive_cum_len=None,
-                draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                spec_steps=self.model_runner.server_args.speculative_num_steps,
+                topk=self.model_runner.server_args.speculative_eagle_topk,
+                draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                 capture_hidden_mode=CaptureHiddenMode.FULL,
             )
         return spec_info
+
+
+CUDA_GRAPH_CAPTURE_FAILED_MSG = (
+    "Possible solutions:\n"
+    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+    "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
+    "3. disable torch compile by not using --enable-torch-compile\n"
+    "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
+    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+)
```
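The troubleshooting text is hoisted into the module-level `CUDA_GRAPH_CAPTURE_FAILED_MSG`, which the two EAGLE runners below import instead of carrying their own copies. A toy sketch of the pattern; `ToyRunner` and `HINTS` are illustrative, not sglang code:

```python
# One shared constant, interpolated by every runner that can fail capture.
HINTS = "Possible solutions:\n1. lower --mem-fraction-static\n"  # abbreviated

class ToyRunner:
    def capture(self):
        raise RuntimeError("CUDA out of memory")  # simulated capture failure

    def __init__(self):
        try:
            self.capture()
        except RuntimeError as e:
            raise Exception(f"Capture cuda graph failed: {e}\n{HINTS}") from e

try:
    ToyRunner()
except Exception as e:
    print(e)
```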
python/sglang/srt/server_args.py

```diff
@@ -447,7 +447,7 @@ class ServerArgs:
                 self.speculative_num_steps,
                 self.speculative_eagle_topk,
                 self.speculative_num_draft_tokens,
-            ) = auto_choose_speculative_params(model_arch)
+            ) = auto_choose_speculative_params(self)
 
         if self.page_size > 1 and self.speculative_eagle_topk > 1:
             self.speculative_eagle_topk = 1
```
```diff
@@ -1655,12 +1655,23 @@ def get_model_arch(args: ServerArgs):
     return hf_config.architectures[0]
 
 
-def auto_choose_speculative_params(arch: str):
+def auto_choose_speculative_params(self: ServerArgs):
     """
     Automatically choose the parameters for speculative decoding.
 
     You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
     """
+    kwargs = {}
+
+    hf_config = get_config(
+        self.model_path,
+        trust_remote_code=self.trust_remote_code,
+        revision=self.revision,
+        model_override_args=json.loads(self.json_model_override_args),
+        **kwargs,
+    )
+    arch = hf_config.architectures[0]
+
     if arch in ["LlamaForCausalLM"]:
         # The default value for llama
         return (5, 4, 8)
```
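`auto_choose_speculative_params` now takes the full `ServerArgs` and resolves the model architecture itself via `get_config`, instead of receiving a precomputed `arch` string. An illustrative sketch of the calling convention; `ToyArgs` and the fallback tuple are assumptions for the example, and only the Llama default `(5, 4, 8)` comes from the diff:

```python
from dataclasses import dataclass

@dataclass
class ToyArgs:
    arch: str  # stands in for get_config(...).architectures[0]

def auto_choose_speculative_params(args: ToyArgs):
    arch = args.arch  # the real code loads hf_config from args.model_path
    if arch in ["LlamaForCausalLM"]:
        # (speculative_num_steps, speculative_eagle_topk, speculative_num_draft_tokens)
        return (5, 4, 8)
    return (5, 4, 8)  # assumed fallback; the real function handles more archs

num_steps, eagle_topk, num_draft_tokens = auto_choose_speculative_params(
    ToyArgs(arch="LlamaForCausalLM")
)
print(num_steps, eagle_topk, num_draft_tokens)  # 5 4 8
```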
python/sglang/srt/speculative/build_eagle_tree.py

```diff
@@ -4,7 +4,7 @@ from typing import List
 import torch
 
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, rank0_print
 
 if is_cuda() or is_hip():
     from sgl_kernel import (
```
```diff
@@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )
 
-    first_rank_print("=========== build tree kernel efficient ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_next_token=}", flush=True)
-    first_rank_print(f"{retrive_next_sibling=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
+    rank0_print("=========== build tree kernel efficient ==========")
+    # rank0_print(f"{tree_mask=}", flush=True)
+    rank0_print(f"{position=}", flush=True)
+    rank0_print(f"{retrive_index=}", flush=True)
+    rank0_print(f"{retrive_next_token=}", flush=True)
+    rank0_print(f"{retrive_next_sibling=}", flush=True)
+    rank0_print(f"{draft_tokens=}", flush=True)
 
     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
```
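The test's `first_rank_print` calls are replaced by `rank0_print`, now imported from `sglang.srt.utils`. A hedged sketch of what such a helper typically looks like; the real implementation may differ:

```python
# Rank-0-only print helper: N-GPU runs emit one copy of the logs.
import torch.distributed as dist

def rank0_print(*args, **kwargs):
    # Print when distributed is uninitialized (single process) or on rank 0.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)

rank0_print("=========== build tree kernel efficient ==========", flush=True)
```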
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

```diff
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable
 import torch
 
 from sglang.srt.model_executor.cuda_graph_runner import (
+    CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
```
```diff
@@ -73,7 +74,7 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
-            (self.max_bs, self.model_runner.model_config.hidden_size),
+            (self.max_num_token, self.model_runner.model_config.hidden_size),
             dtype=self.model_runner.dtype,
         )
```
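The draft runner's `hidden_states` buffer is resized from `max_bs` rows to `max_num_token` rows. Presumably (my inference, not stated in the commit) each captured sequence can carry more than one draft token per step, so the row count must be bounded by tokens rather than sequences. Illustrative numbers only:

```python
# Why a token-count bound differs from a batch-size bound when each
# sequence carries several draft tokens (all values are made up).
max_bs = 8                 # captured batch sizes go up to 8 sequences
num_tokens_per_bs = 4      # e.g. 4 draft candidates per sequence
max_num_token = max_bs * num_tokens_per_bs

hidden_size = 4096
print((max_bs, hidden_size))         # (8, 4096): too few rows once tokens > sequences
print((max_num_token, hidden_size))  # (32, 4096): one row per draft token
```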
```diff
@@ -82,13 +83,7 @@ class EAGLEDraftCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture CUDA graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )
 
     def can_run(self, forward_batch: ForwardBatch):
```
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

```diff
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable
 import torch
 
 from sglang.srt.model_executor.cuda_graph_runner import (
+    CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
     LogitsProcessorOutput,
     get_batch_sizes_to_capture,
```
```diff
@@ -89,13 +90,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture CUDA graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            )
 
     def can_run(self, forward_batch: ForwardBatch):
```
```diff
@@ -200,7 +195,6 @@ class EAGLEDraftExtendCudaGraphRunner:
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]
-        assert raw_bs * self.num_tokens_per_bs == num_tokens
 
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
```
```diff
@@ -224,9 +218,9 @@ class EAGLEDraftExtendCudaGraphRunner:
             self.seq_lens_cpu.fill_(1)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
-        forward_batch.spec_info.positions = None
         if bs != raw_bs:
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
+            forward_batch.spec_info.positions = None
 
         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
             bs=bs,
```
python/sglang/srt/speculative/eagle_utils.py

```diff
@@ -232,8 +232,9 @@ class EagleVerifyInput:
     retrive_next_token: torch.Tensor
     retrive_next_sibling: torch.Tensor
     retrive_cum_len: torch.Tensor
-    draft_token_num: int
     spec_steps: int
+    topk: int
+    draft_token_num: int
     capture_hidden_mode: CaptureHiddenMode
     grammar: BaseGrammarObject = None
```
```diff
@@ -270,16 +271,17 @@ class EagleVerifyInput:
         )
 
         return cls(
-            draft_tokens,
-            tree_mask,
-            position,
-            retrive_index,
-            retrive_next_token,
-            retrive_next_sibling,
-            None,
-            num_verify_tokens,
-            spec_steps,
-            CaptureHiddenMode.FULL,
+            draft_token=draft_tokens,
+            custom_mask=tree_mask,
+            positions=position,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            retrive_cum_len=None,
+            spec_steps=spec_steps,
+            topk=topk,
+            draft_token_num=num_verify_tokens,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
         )
 
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
```
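The two hunks in this file go together: once the dataclass field order changes (here `topk` lands between `spec_steps` and `draft_token_num`), positional construction would silently shuffle values, which is presumably why the `cls(...)` call above switches to keyword arguments. A toy illustration; `VerifyInput` is not the real class:

```python
# Positional vs. keyword construction of a dataclass whose field order may change.
from dataclasses import dataclass

@dataclass
class VerifyInput:
    spec_steps: int
    topk: int
    draft_token_num: int

# Positional: meaning depends on field order and breaks silently if it changes.
a = VerifyInput(4, 8, 16)
# Keyword: robust to reordering, and self-documenting at the call site.
b = VerifyInput(spec_steps=4, topk=8, draft_token_num=16)
assert a == b
```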