sglang · Commits · 0c1f03a2

Sync cuda graph runners (#6976)

Unverified commit 0c1f03a2, authored Jun 08, 2025 by Lianmin Zheng, committed by GitHub on Jun 08, 2025.
Parent commit: 3712abfa

Showing 5 changed files with 58 additions and 51 deletions (+58 −51):

  python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py          +2  −2
  python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py   +1  −1
  python/sglang/srt/speculative/eagle_utils.py                            +50 −48
  python/sglang/srt/speculative/eagle_worker.py                           +1  −0
  test/srt/test_eagle_infer.py                                            +4  −0
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -127,7 +127,7 @@ class EAGLEDraftCudaGraphRunner:
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens.sum(),
+            seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
             spec_algorithm=self.model_runner.spec_algorithm,
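Why `.item()`: `seq_lens.sum()` yields a 0-dim tensor on the GPU, while `seq_lens_sum` is consumed as host-side batch metadata, so converting it to a plain Python int once avoids implicit device-to-host syncs on every later read. A minimal sketch of the difference (illustrative only, not sglang code; falls back to CPU when no GPU is present):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seq_lens = torch.tensor([5, 9, 3], device=device)

    total_tensor = seq_lens.sum()         # 0-dim tensor, stays on `device`
    total_scalar = seq_lens.sum().item()  # plain Python int, one explicit sync

    # Host-side bookkeeping wants the int form; each host read of the
    # tensor form would trigger its own device-to-host synchronization.
    assert isinstance(total_scalar, int)
    print(total_tensor.device, total_scalar)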
@@ -209,7 +209,7 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:num_tokens]

         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
+        if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
             self.seq_lens_cpu.fill_(1)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
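The hunk above reuses a persistent CPU buffer so a graph captured at batch size `bs` can serve a smaller live batch `raw_bs`: padded slots get a dummy length of 1 and the real lengths are copied into the prefix. A standalone sketch of that buffer pattern (assumed sizes, not sglang's API):

    import torch

    bs, raw_bs = 8, 3                                     # captured vs. live batch size
    seq_lens_cpu_buf = torch.ones(bs, dtype=torch.int64)  # persistent buffer

    real_seq_lens = torch.tensor([17, 4, 42])             # this step's lengths
    seq_lens_cpu_buf.fill_(1)                             # dummy length for padded slots
    seq_lens_cpu_buf[:raw_bs].copy_(real_seq_lens)        # overwrite the live prefix
    padded = seq_lens_cpu_buf[:bs]                        # view passed to the forward batch
    print(padded)  # tensor([17,  4, 42,  1,  1,  1,  1,  1])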
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -138,7 +138,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens.sum(),
+            seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
             spec_algorithm=self.model_runner.spec_algorithm,
python/sglang/srt/speculative/eagle_utils.py

 from __future__ import annotations

+import logging
 import os
+import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import List, Optional

 import torch
 import torch.nn.functional as F

@@ -12,6 +14,7 @@ import triton.language as tl
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import apply_custom_logit_processor
 from sglang.srt.managers.schedule_batch import (
     Req,
     ScheduleBatch,

@@ -20,7 +23,6 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
 from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2

@@ -34,15 +36,15 @@ if is_cuda():
 elif is_hip():
     from sgl_kernel import verify_tree_greedy

-if TYPE_CHECKING:
-    from sglang.srt.managers.schedule_batch import ScheduleBatch
-
-import logging
-
 logger = logging.getLogger(__name__)

+# Simulate acceptance length for benchmarking purposes
 SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
+
+TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly

 @dataclass
@@ -84,9 +86,9 @@ class EagleDraftInput:
         self,
         batch: ScheduleBatch,
         speculative_num_steps: int,
+        context_length: int,
         pad_input: bool = False,
     ):
-        assert len(self.verified_id) == len(batch.out_cache_loc)
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
@@ -112,49 +114,49 @@ class EagleDraftInput:
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id

-        if pad_input:
-            batch_size = sum(not req.finished() for req in batch.reqs)
-            # Total constant input length after padding
-            static_len = speculative_num_steps + 1
-            # Total size after padding
-            padded_input_size = batch_size * static_len
-            padded_len = padded_input_size - batch.input_ids.shape[0]
-            if padded_len > 0:
-                new_input_ids = torch.nn.functional.pad(
-                    batch.input_ids, (0, padded_len), value=0
-                )
-                position_padding = torch.arange(padded_len, device=self.positions.device)
-                new_positions = torch.cat([self.positions, position_padding])
-                # need dummy hidden states for the padded positions
-                hidden_states_dim = self.hidden_states.shape[-1]
-                new_hidden_states = torch.cat(
-                    [
-                        self.hidden_states,
-                        torch.zeros(
-                            (padded_len, hidden_states_dim),
-                            dtype=self.hidden_states.dtype,
-                            device=self.hidden_states.device,
-                        ),
-                    ],
-                    dim=0,
-                )
-                # allocate KV cache location for the padded tokens
-                padded_cache_loc = torch.zeros(
-                    padded_len,
-                    dtype=batch.out_cache_loc.dtype,
-                    device=batch.out_cache_loc.device,
-                )
-                new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
-                batch.input_ids = new_input_ids
-                self.hidden_states = new_hidden_states
-                self.positions = new_positions
-                batch.out_cache_loc = new_out_cache_loc
+        if not pad_input:
+            return
+
+        batch_size = sum(not req.finished() for req in batch.reqs)
+        # Total constant input length after padding
+        static_len = speculative_num_steps + 1
+        # Total size after padding
+        padded_input_size = batch_size * static_len
+        padded_len = padded_input_size - batch.input_ids.shape[0]
+        if padded_len > 0:
+            new_input_ids = torch.nn.functional.pad(
+                batch.input_ids, (0, padded_len), value=0
+            )
+            position_padding = torch.arange(padded_len, device=self.positions.device)
+            new_positions = torch.cat([self.positions, position_padding])
+            # need dummy hidden states for the padded positions
+            hidden_states_dim = self.hidden_states.shape[-1]
+            new_hidden_states = torch.cat(
+                [
+                    self.hidden_states,
+                    torch.zeros(
+                        (padded_len, hidden_states_dim),
+                        dtype=self.hidden_states.dtype,
+                        device=self.hidden_states.device,
+                    ),
+                ],
+                dim=0,
+            )
+            # allocate KV cache location for the padded tokens
+            padded_cache_loc = torch.zeros(
+                padded_len,
+                dtype=batch.out_cache_loc.dtype,
+                device=batch.out_cache_loc.device,
+            )
+            new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
+            batch.input_ids = new_input_ids
+            self.hidden_states = new_hidden_states
+            self.positions = new_positions
+            batch.out_cache_loc = new_out_cache_loc

     def generate_attn_arg_prefill(
         self,
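The rewrite turns the `if pad_input:` wrapper into an early-return guard and dedents the padding logic. The invariant it relies on: each unfinished request contributes exactly `speculative_num_steps + 1` slots, so the padded input size is a constant per batch size and a CUDA graph captured at that size can be replayed no matter how many draft tokens were actually accepted. A self-contained sketch of the sizing arithmetic (assumed tensors, not sglang's data structures):

    import torch

    def pad_to_static(input_ids: torch.Tensor, batch_size: int, steps: int) -> torch.Tensor:
        static_len = steps + 1                        # constant per-request length
        padded_input_size = batch_size * static_len   # total size after padding
        padded_len = padded_input_size - input_ids.shape[0]
        if padded_len > 0:
            # right-pad the flat token buffer with zeros
            input_ids = torch.nn.functional.pad(input_ids, (0, padded_len), value=0)
        return input_ids

    ids = torch.tensor([11, 12, 13, 21, 31, 32])  # 3 requests, ragged accept lengths
    print(pad_to_static(ids, batch_size=3, steps=3).shape)  # torch.Size([12])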
python/sglang/srt/speculative/eagle_worker.py

@@ -687,6 +687,7 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_extend_after_decode(
             batch,
             self.speculative_num_steps,
+            self.server_args.context_length,
             pad_input=self.cuda_graph_runner_for_draft_extend is not None,
         )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
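The worker now forwards the server's context length and asks for padding only when the draft-extend pass will replay through a captured CUDA graph; eager execution handles ragged shapes without padding. A sketch of that call-site pattern, using hypothetical parameter names in place of the worker's attributes:

    # Hypothetical wrapper mirroring the call in the diff above.
    def prepare_draft_extend(spec_info, batch, speculative_num_steps,
                             context_length, draft_extend_graph_runner):
        spec_info.prepare_extend_after_decode(
            batch,
            speculative_num_steps,
            context_length,
            # pad to static shapes only when a CUDA graph will run this pass
            pad_input=draft_extend_graph_runner is not None,
        )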
test/srt/test_eagle_infer.py

@@ -23,6 +23,7 @@ from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     CustomTestCase,
+    is_in_ci,
     popen_launch_server,
     run_logprob_check,
 )

@@ -578,6 +579,7 @@ class TestEAGLEServerTriton(TestEAGLEServer):
 )

+@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
 class TestEAGLEDraftExtend(CustomTestCase):
     @classmethod
     def setUpClass(cls):

@@ -669,6 +671,7 @@ class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
         cls.accept_len_threshold = 1.50

+@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
 class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
     @classmethod
     def setUpClass(cls):

@@ -697,6 +700,7 @@ class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
         cls.accept_len_threshold = 1.50

+@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
 class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
     @classmethod
     def setUpClass(cls):
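The new class-level `@unittest.skipIf(is_in_ci(), ...)` decorators skip all three draft-extend suites in CI; applied to a class, `skipIf` skips every test method when the predicate is true at collection time. A runnable sketch (the `is_in_ci` stand-in below just checks the conventional `CI` environment variable; sglang's real helper is imported from `sglang.test.test_utils`):

    import os
    import unittest

    def is_in_ci() -> bool:
        # Stand-in predicate: assumes the conventional "CI" env var.
        return os.environ.get("CI", "").lower() in ("1", "true", "yes")

    @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
    class ExpensiveSuite(unittest.TestCase):
        # Every test in this class is skipped when running in CI.
        def test_heavy_path(self):
            self.assertEqual(sum(range(10)), 45)

    if __name__ == "__main__":
        unittest.main()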