change / sglang / Commits / 53a525bf

Commit 53a525bf (unverified)
Authored Jun 16, 2025 by Lianmin Zheng; committed by GitHub on Jun 16, 2025

[Eagle] Fix kernel call after updating speculative sampling kernels (#7231)

Parent: 7ddf8e83
Showing 7 changed files with 24 additions and 33 deletions (+24 / -33):

  docker/Dockerfile.blackwell                         +1   -1
  python/pyproject.toml                               +1   -1
  python/sglang/srt/entrypoints/engine.py             +1   -1
  python/sglang/srt/speculative/build_eagle_tree.py   +1   -1
  python/sglang/srt/speculative/eagle_utils.py        +12  -21
  python/sglang/srt/speculative/eagle_worker.py       +1   -1
  test/srt/test_fa3.py                                +7   -7
docker/Dockerfile.blackwell

@@ -20,7 +20,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
 RUN pip3 install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128 --break-system-packages
-RUN pip3 install https://github.com/sgl-project/whl/releases/download/v0.1.8.post2/sgl_kernel-0.1.8.post2+cu128-cp39-abi3-manylinux2014_x86_64.whl --break-system-packages \
+RUN pip3 install https://github.com/sgl-project/whl/releases/download/v0.1.9/sgl_kernel-0.1.9+cu128-cp39-abi3-manylinux2014_x86_64.whl --break-system-packages \
     && pip3 install setuptools==75.0.0 wheel scikit-build-core --break-system-packages
 RUN git clone --depth=1 https://github.com/sgl-project/sglang.git \
python/pyproject.toml

@@ -49,7 +49,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.1.8.post2",
+    "sgl-kernel==0.1.9",
     "flashinfer_python==0.2.6.post1",
     "torch==2.7.1",
     "torchaudio==2.7.1",
python/sglang/srt/entrypoints/engine.py

@@ -605,7 +605,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.8.post2",
+            "0.1.9",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
python/sglang/srt/speculative/build_eagle_tree.py

@@ -92,7 +92,7 @@ def build_tree_kernel_efficient(
     sgl_build_tree_kernel_efficient(
         parent_list,
         top_scores_index,
-        seq_lens.to(torch.int32),
+        seq_lens,
         tree_mask,
         positions,
         retrive_index,
python/sglang/srt/speculative/eagle_utils.py

@@ -23,7 +23,7 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
-from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2

 if is_cuda():
     from sgl_kernel import (
@@ -32,6 +32,7 @@ if is_cuda():
         tree_speculative_sampling_target_only,
         verify_tree_greedy,
     )
+    from sgl_kernel.top_k import fast_topk
 elif is_hip():
     from sgl_kernel import verify_tree_greedy
@@ -327,11 +328,11 @@ class EagleVerifyInput:
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
-                candidates=candidates.to(torch.int32),
-                retrive_index=self.retrive_index.to(torch.int32),
-                retrive_next_token=self.retrive_next_token.to(torch.int32),
-                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
-                target_predict=target_predict.to(torch.int32),
+                candidates=candidates,
+                retrive_index=self.retrive_index,
+                retrive_next_token=self.retrive_next_token,
+                retrive_next_sibling=self.retrive_next_sibling,
+                target_predict=target_predict,
             )
         else:
             # apply temperature and get target probs
@@ -370,12 +371,12 @@ class EagleVerifyInput:
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
-                candidates=candidates.to(torch.int32),
-                retrive_index=self.retrive_index.to(torch.int32),
-                retrive_next_token=self.retrive_next_token.to(torch.int32),
-                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
+                candidates=candidates,
+                retrive_index=self.retrive_index,
+                retrive_next_token=self.retrive_next_token,
+                retrive_next_sibling=self.retrive_next_sibling,
                 uniform_samples=coins,
-                # uniform_samples_for_final_sampling=coins_for_final_sampling,
+                uniform_samples_for_final_sampling=coins_for_final_sampling,
                 target_probs=target_probs,
                 draft_probs=draft_probs,
                 threshold_single=global_server_args_dict[
@@ -1005,16 +1006,6 @@ def select_top_k_tokens(
     return input_ids, hidden_states, scores, tree_info


-def fast_topk_torch(values, topk, dim):
-    if topk == 1:
-        # Use max along the specified dimension to get both value and index
-        max_value, max_index = torch.max(values, dim=dim)
-        return max_value.unsqueeze(1), max_index.unsqueeze(1)
-    else:
-        # Use topk for efficiency with larger k values
-        return torch.topk(values, topk, dim=dim)
-
-
 def _generate_simulated_accept_index(
     accept_index,
     predict,
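Note on the removed helper: fast_topk_torch was a pure-PyTorch fallback for the fused top-k routine that this commit now imports from sgl_kernel.top_k. A self-contained sketch of that fallback, reconstructed from the removed lines above (the usage example and its tensor shapes are illustrative only, not from the sglang test suite):

import torch


def fast_topk_torch(values: torch.Tensor, topk: int, dim: int):
    """Pure-PyTorch fallback mirroring the removed helper."""
    if topk == 1:
        # torch.max returns both the value and the index along `dim`,
        # which is cheaper than a full top-k when k == 1.
        max_value, max_index = torch.max(values, dim=dim)
        return max_value.unsqueeze(1), max_index.unsqueeze(1)
    else:
        # Fall back to torch.topk for larger k.
        return torch.topk(values, topk, dim=dim)


# Illustrative usage: per-row top-1 over draft-token scores.
scores = torch.rand(4, 8)
top_values, top_indices = fast_topk_torch(scores, topk=1, dim=-1)
print(top_values.shape, top_indices.shape)  # torch.Size([4, 1]) torch.Size([4, 1])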
python/sglang/srt/speculative/eagle_worker.py

@@ -828,7 +828,7 @@ def load_token_map(token_map_path: str) -> List[int]:
         )
     token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
     hot_token_id = torch.load(token_map_path, weights_only=True)
-    return torch.tensor(hot_token_id, dtype=torch.int32)
+    return torch.tensor(hot_token_id, dtype=torch.int64)


 @torch.compile(dynamic=True)
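The int32-to-int64 change above likely reflects that the returned hot_token_id tensor is consumed as an index tensor downstream, and some PyTorch indexing ops (torch.gather in particular) require int64 indices. A minimal illustration of that requirement; the variable names and values below are hypothetical, not taken from the sglang code:

import torch

# Hypothetical stand-ins: a vocab mapping and per-step top-k indices.
hot_token_id = torch.tensor([11, 42, 7, 99], dtype=torch.int64)
topk_index = torch.tensor([[0, 3], [2, 1]])  # indices into hot_token_id

# Advanced indexing with the int64 mapping works directly ...
mapped = hot_token_id[topk_index]  # tensor([[11, 99], [ 7, 42]])

# ... and torch.gather requires an int64 index tensor.
gathered = torch.gather(hot_token_id.expand(2, -1), 1, topk_index)
print(mapped)
print(gathered)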
test/srt/test_fa3.py

@@ -143,7 +143,7 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
         args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--speculative-algorithm",
                 "EAGLE3",
                 "--speculative-draft",
@@ -169,7 +169,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
     model = DEFAULT_MODEL_NAME_FOR_TEST
     accuracy_threshold = 0.65
     speculative_decode = True
-    spec_decode_threshold = 1.5
+    spec_decode_threshold = 1.6

     @classmethod
     def get_server_args(cls):
@@ -177,7 +177,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
         args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--speculative-algorithm",
                 "EAGLE3",
                 "--speculative-draft",
@@ -201,7 +201,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
     model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
     accuracy_threshold = 0.60
     speculative_decode = True
-    spec_decode_threshold = 1.5
+    spec_decode_threshold = 2.5

     @classmethod
     def get_server_args(cls):
@@ -209,7 +209,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
         args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--speculative-algorithm",
                 "EAGLE",
                 "--speculative-draft",
@@ -233,7 +233,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
     model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
     accuracy_threshold = 0.60
     speculative_decode = True
-    spec_decode_threshold = 1.5
+    spec_decode_threshold = 2.95

     @classmethod
     def get_server_args(cls):
@@ -241,7 +241,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
         args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--speculative-algorithm",
                 "EAGLE",
                 "--speculative-draft",