Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0c1f03a2
"references/vscode:/vscode.git/clone" did not exist on "0dceac025615a1c2df6ec1675d8f9d7757432a49"
Unverified
Commit
0c1f03a2
authored
Jun 08, 2025
by
Lianmin Zheng
Committed by
GitHub
Jun 08, 2025
Browse files
Sync cuda graph runners (#6976)
parent
3712abfa
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
58 additions
and
51 deletions
+58
-51
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
...n/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+2
-2
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
...g/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+1
-1
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+50
-48
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+1
-0
test/srt/test_eagle_infer.py
test/srt/test_eagle_infer.py
+4
-0
No files found.
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
View file @
0c1f03a2
...
@@ -127,7 +127,7 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -127,7 +127,7 @@ class EAGLEDraftCudaGraphRunner:
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
out_cache_loc
=
out_cache_loc
,
out_cache_loc
=
out_cache_loc
,
seq_lens_sum
=
seq_lens
.
sum
(),
seq_lens_sum
=
seq_lens
.
sum
()
.
item
()
,
return_logprob
=
False
,
return_logprob
=
False
,
positions
=
positions
,
positions
=
positions
,
spec_algorithm
=
self
.
model_runner
.
spec_algorithm
,
spec_algorithm
=
self
.
model_runner
.
spec_algorithm
,
...
@@ -209,7 +209,7 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -209,7 +209,7 @@ class EAGLEDraftCudaGraphRunner:
forward_batch
.
positions
=
self
.
positions
[:
num_tokens
]
forward_batch
.
positions
=
self
.
positions
[:
num_tokens
]
# Special handle for seq_len_cpu used when flashinfer mla is used
# Special handle for seq_len_cpu used when flashinfer mla is used
if
(
forward_batch
.
seq_lens_cpu
is
not
None
)
and
(
bs
!=
raw_bs
)
:
if
forward_batch
.
seq_lens_cpu
is
not
None
and
bs
!=
raw_bs
:
self
.
seq_lens_cpu
.
fill_
(
1
)
self
.
seq_lens_cpu
.
fill_
(
1
)
self
.
seq_lens_cpu
[:
raw_bs
].
copy_
(
forward_batch
.
seq_lens_cpu
)
self
.
seq_lens_cpu
[:
raw_bs
].
copy_
(
forward_batch
.
seq_lens_cpu
)
forward_batch
.
seq_lens_cpu
=
self
.
seq_lens_cpu
[:
bs
]
forward_batch
.
seq_lens_cpu
=
self
.
seq_lens_cpu
[:
bs
]
...
...
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
View file @
0c1f03a2
...
@@ -138,7 +138,7 @@ class EAGLEDraftExtendCudaGraphRunner:
...
@@ -138,7 +138,7 @@ class EAGLEDraftExtendCudaGraphRunner:
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
out_cache_loc
=
out_cache_loc
,
out_cache_loc
=
out_cache_loc
,
seq_lens_sum
=
seq_lens
.
sum
(),
seq_lens_sum
=
seq_lens
.
sum
()
.
item
()
,
return_logprob
=
False
,
return_logprob
=
False
,
positions
=
positions
,
positions
=
positions
,
spec_algorithm
=
self
.
model_runner
.
spec_algorithm
,
spec_algorithm
=
self
.
model_runner
.
spec_algorithm
,
...
...
python/sglang/srt/speculative/eagle_utils.py
View file @
0c1f03a2
from
__future__
import
annotations
from
__future__
import
annotations
import
logging
import
os
import
os
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
List
,
Optional
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -12,6 +14,7 @@ import triton.language as tl
...
@@ -12,6 +14,7 @@ import triton.language as tl
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.managers.schedule_batch
import
(
from
sglang.srt.managers.schedule_batch
import
(
Req
,
Req
,
ScheduleBatch
,
ScheduleBatch
,
...
@@ -20,7 +23,6 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -20,7 +23,6 @@ from sglang.srt.managers.schedule_batch import (
)
)
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.utils
import
fast_topk
,
is_cuda
,
is_hip
,
next_power_of_2
from
sglang.srt.utils
import
fast_topk
,
is_cuda
,
is_hip
,
next_power_of_2
...
@@ -34,15 +36,15 @@ if is_cuda():
...
@@ -34,15 +36,15 @@ if is_cuda():
elif
is_hip
():
elif
is_hip
():
from
sgl_kernel
import
verify_tree_greedy
from
sgl_kernel
import
verify_tree_greedy
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
import
logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN
=
os
.
environ
.
get
(
"SIMULATE_ACC_LEN"
)
SIMULATE_ACC_LEN
=
os
.
environ
.
get
(
"SIMULATE_ACC_LEN"
)
SIMULATE_ACC_METHOD
=
os
.
environ
.
get
(
"SIMULATE_ACC_METHOD"
,
"multinomial"
)
TREE_TRAVERSE_TIME_THRESHOLD
=
1
# TODO: set this properly
@
dataclass
@
dataclass
...
@@ -84,9 +86,9 @@ class EagleDraftInput:
...
@@ -84,9 +86,9 @@ class EagleDraftInput:
self
,
self
,
batch
:
ScheduleBatch
,
batch
:
ScheduleBatch
,
speculative_num_steps
:
int
,
speculative_num_steps
:
int
,
context_length
:
int
,
pad_input
:
bool
=
False
,
pad_input
:
bool
=
False
,
):
):
assert
len
(
self
.
verified_id
)
==
len
(
batch
.
out_cache_loc
)
accept_length_cpu
=
batch
.
spec_info
.
accept_length_cpu
accept_length_cpu
=
batch
.
spec_info
.
accept_length_cpu
batch
.
extend_lens
=
[
x
+
1
for
x
in
accept_length_cpu
]
batch
.
extend_lens
=
[
x
+
1
for
x
in
accept_length_cpu
]
batch
.
extend_num_tokens
=
sum
(
batch
.
extend_lens
)
batch
.
extend_num_tokens
=
sum
(
batch
.
extend_lens
)
...
@@ -112,49 +114,49 @@ class EagleDraftInput:
...
@@ -112,49 +114,49 @@ class EagleDraftInput:
batch
.
input_ids
=
self
.
verified_id
batch
.
input_ids
=
self
.
verified_id
self
.
verified_id
=
new_verified_id
self
.
verified_id
=
new_verified_id
if
pad_input
:
if
not
pad_input
:
batch_size
=
sum
(
not
req
.
finished
()
for
req
in
batch
.
reqs
)
return
# Total constant input length after padding
static_len
=
speculative_num_steps
+
1
# Total size after padding
padded_input_size
=
batch_size
*
static_len
padded_len
=
padded_input_size
-
batch
.
input_ids
.
shape
[
0
]
if
padded_len
>
0
:
new_input_ids
=
torch
.
nn
.
functional
.
pad
(
batch
.
input_ids
,
(
0
,
padded_len
),
value
=
0
)
position_padding
=
torch
.
arange
(
padded_len
,
device
=
self
.
positions
.
device
)
new_positions
=
torch
.
cat
([
self
.
positions
,
position_padding
])
# need dummy hidden states for the padded positions
hidden_states_dim
=
self
.
hidden_states
.
shape
[
-
1
]
new_hidden_states
=
torch
.
cat
(
[
self
.
hidden_states
,
torch
.
zeros
(
(
padded_len
,
hidden_states_dim
),
dtype
=
self
.
hidden_states
.
dtype
,
device
=
self
.
hidden_states
.
device
,
),
],
dim
=
0
,
)
# allocate KV cache location for the padded tokens
batch_size
=
sum
(
not
req
.
finished
()
for
req
in
batch
.
reqs
)
padded_cache_loc
=
torch
.
zeros
(
# Total constant input length after padding
padded_len
,
static_len
=
speculative_num_steps
+
1
dtype
=
batch
.
out_cache_loc
.
dtype
,
# Total size after padding
device
=
batch
.
out_cache_loc
.
device
,
padded_input_size
=
batch_size
*
static_len
)
new_out_cache_loc
=
torch
.
cat
([
batch
.
out_cache_loc
,
padded_cache_loc
])
padded_len
=
padded_input_size
-
batch
.
input_ids
.
shape
[
0
]
if
padded_len
>
0
:
new_input_ids
=
torch
.
nn
.
functional
.
pad
(
batch
.
input_ids
,
(
0
,
padded_len
),
value
=
0
)
position_padding
=
torch
.
arange
(
padded_len
,
device
=
self
.
positions
.
device
)
new_positions
=
torch
.
cat
([
self
.
positions
,
position_padding
])
# need dummy hidden states for the padded positions
hidden_states_dim
=
self
.
hidden_states
.
shape
[
-
1
]
new_hidden_states
=
torch
.
cat
(
[
self
.
hidden_states
,
torch
.
zeros
(
(
padded_len
,
hidden_states_dim
),
dtype
=
self
.
hidden_states
.
dtype
,
device
=
self
.
hidden_states
.
device
,
),
],
dim
=
0
,
)
# allocate KV cache location for the padded tokens
padded_cache_loc
=
torch
.
zeros
(
padded_len
,
dtype
=
batch
.
out_cache_loc
.
dtype
,
device
=
batch
.
out_cache_loc
.
device
,
)
new_out_cache_loc
=
torch
.
cat
([
batch
.
out_cache_loc
,
padded_cache_loc
])
batch
.
input_ids
=
new_input_ids
batch
.
input_ids
=
new_input_ids
self
.
hidden_states
=
new_hidden_states
self
.
hidden_states
=
new_hidden_states
self
.
positions
=
new_positions
self
.
positions
=
new_positions
batch
.
out_cache_loc
=
new_out_cache_loc
batch
.
out_cache_loc
=
new_out_cache_loc
def
generate_attn_arg_prefill
(
def
generate_attn_arg_prefill
(
self
,
self
,
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
0c1f03a2
...
@@ -687,6 +687,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -687,6 +687,7 @@ class EAGLEWorker(TpModelWorker):
batch
.
spec_info
.
prepare_extend_after_decode
(
batch
.
spec_info
.
prepare_extend_after_decode
(
batch
,
batch
,
self
.
speculative_num_steps
,
self
.
speculative_num_steps
,
self
.
server_args
.
context_length
,
pad_input
=
self
.
cuda_graph_runner_for_draft_extend
is
not
None
,
pad_input
=
self
.
cuda_graph_runner_for_draft_extend
is
not
None
,
)
)
batch
.
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
batch
.
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
...
...
test/srt/test_eagle_infer.py
View file @
0c1f03a2
...
@@ -23,6 +23,7 @@ from sglang.test.test_utils import (
...
@@ -23,6 +23,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
CustomTestCase
,
is_in_ci
,
popen_launch_server
,
popen_launch_server
,
run_logprob_check
,
run_logprob_check
,
)
)
...
@@ -578,6 +579,7 @@ class TestEAGLEServerTriton(TestEAGLEServer):
...
@@ -578,6 +579,7 @@ class TestEAGLEServerTriton(TestEAGLEServer):
)
)
@
unittest
.
skipIf
(
is_in_ci
(),
"To reduce the CI execution time."
)
class
TestEAGLEDraftExtend
(
CustomTestCase
):
class
TestEAGLEDraftExtend
(
CustomTestCase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
...
@@ -669,6 +671,7 @@ class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
...
@@ -669,6 +671,7 @@ class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
cls
.
accept_len_threshold
=
1.50
cls
.
accept_len_threshold
=
1.50
@
unittest
.
skipIf
(
is_in_ci
(),
"To reduce the CI execution time."
)
class
TestEAGLEDraftExtendTriton
(
TestEAGLEDraftExtend
):
class
TestEAGLEDraftExtendTriton
(
TestEAGLEDraftExtend
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
...
@@ -697,6 +700,7 @@ class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
...
@@ -697,6 +700,7 @@ class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
cls
.
accept_len_threshold
=
1.50
cls
.
accept_len_threshold
=
1.50
@
unittest
.
skipIf
(
is_in_ci
(),
"To reduce the CI execution time."
)
class
TestEAGLEDraftExtendFlashinferMLA
(
TestEAGLEDraftExtend
):
class
TestEAGLEDraftExtendFlashinferMLA
(
TestEAGLEDraftExtend
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment