Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
fad315cb
Unverified
Commit
fad315cb
authored
Feb 09, 2025
by
Yineng Zhang
Committed by
GitHub
Feb 09, 2025
Browse files
fix EAGLE 2 non greedy case (#3407)
Co-authored-by:
Ying Sheng
<
sqy1415@gmail.com
>
parent
f90db8bc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
71 additions
and
22 deletions
+71
-22
benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
...d_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
+3
-1
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+3
-0
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+64
-21
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+1
-0
No files found.
benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
View file @
fad315cb
...
@@ -54,7 +54,9 @@ def get_model_config(model_name: str, tp_size: int):
...
@@ -54,7 +54,9 @@ def get_model_config(model_name: str, tp_size: int):
):
):
block_shape
=
config
.
quantization_config
[
"weight_block_size"
]
block_shape
=
config
.
quantization_config
[
"weight_block_size"
]
assert
len
(
block_shape
)
==
2
assert
len
(
block_shape
)
==
2
assert
vllm_version_num
>=
66
,
"Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
assert
(
vllm_version_num
>=
66
),
"Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
shape_configs
=
{
shape_configs
=
{
"num_experts"
:
E
,
"num_experts"
:
E
,
...
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
fad315cb
...
@@ -462,8 +462,11 @@ class CudaGraphRunner:
...
@@ -462,8 +462,11 @@ class CudaGraphRunner:
),
),
positions
=
None
,
positions
=
None
,
retrive_index
=
None
,
retrive_index
=
None
,
retrive_next_token
=
None
,
retrive_next_sibling
=
None
,
retrive_cum_len
=
None
,
retrive_cum_len
=
None
,
draft_token_num
=
self
.
model_runner
.
server_args
.
speculative_num_draft_tokens
,
draft_token_num
=
self
.
model_runner
.
server_args
.
speculative_num_draft_tokens
,
spec_steps
=
self
.
model_runner
.
server_args
.
speculative_num_steps
,
capture_hidden_mode
=
CaptureHiddenMode
.
FULL
,
capture_hidden_mode
=
CaptureHiddenMode
.
FULL
,
)
)
...
...
python/sglang/srt/speculative/eagle_utils.py
View file @
fad315cb
...
@@ -4,6 +4,7 @@ import dataclasses
...
@@ -4,6 +4,7 @@ import dataclasses
from
typing
import
TYPE_CHECKING
,
List
from
typing
import
TYPE_CHECKING
,
List
import
torch
import
torch
import
torch.nn.functional
as
F
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
...
@@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
...
@@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton
,
create_flashinfer_kv_indices_triton
,
)
)
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel
from
sglang.srt.speculative.build_eagle_tree
import
(
build_tree_kernel
,
build_tree_kernel_efficient
,
)
from
sglang.srt.utils
import
is_cuda_available
if
is_cuda_available
():
from
sgl_kernel
import
tree_speculative_sampling_target_only
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
...
@@ -160,8 +168,11 @@ class EagleVerifyInput:
...
@@ -160,8 +168,11 @@ class EagleVerifyInput:
custom_mask
:
torch
.
Tensor
custom_mask
:
torch
.
Tensor
positions
:
torch
.
Tensor
positions
:
torch
.
Tensor
retrive_index
:
torch
.
Tensor
retrive_index
:
torch
.
Tensor
retrive_next_token
:
torch
.
Tensor
retrive_next_sibling
:
torch
.
Tensor
retrive_cum_len
:
torch
.
Tensor
retrive_cum_len
:
torch
.
Tensor
draft_token_num
:
int
draft_token_num
:
int
spec_steps
:
int
capture_hidden_mode
:
CaptureHiddenMode
capture_hidden_mode
:
CaptureHiddenMode
@
classmethod
@
classmethod
...
@@ -175,10 +186,45 @@ class EagleVerifyInput:
...
@@ -175,10 +186,45 @@ class EagleVerifyInput:
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
topk
:
int
,
topk
:
int
,
spec_steps
:
int
,
spec_steps
:
int
,
num_verify_token
:
int
,
num_verify_tokens
:
int
,
is_all_greedy
:
bool
,
):
):
tree_mask
,
position
,
retrive_index
,
retrive_cum_len
,
draft_tokens
=
(
if
is_all_greedy
:
build_tree_kernel
(
tree_mask
,
position
,
retrive_index
,
retrive_cum_len
,
draft_tokens
=
(
build_tree_kernel
(
verified_id
,
score_list
,
# b, n, topk; n= 1 + (num_steps-1) * self.topk
token_list
,
parents_list
,
seq_lens
,
seq_lens_sum
,
topk
,
spec_steps
,
num_verify_tokens
,
)
)
return
cls
(
draft_tokens
,
tree_mask
,
position
,
retrive_index
,
None
,
None
,
retrive_cum_len
,
num_verify_tokens
,
spec_steps
,
CaptureHiddenMode
.
FULL
,
)
else
:
(
tree_mask
,
position
,
retrive_index
,
retrive_next_token
,
retrive_next_sibling
,
draft_tokens
,
)
=
build_tree_kernel_efficient
(
verified_id
,
verified_id
,
score_list
,
score_list
,
token_list
,
token_list
,
...
@@ -187,18 +233,21 @@ class EagleVerifyInput:
...
@@ -187,18 +233,21 @@ class EagleVerifyInput:
seq_lens_sum
,
seq_lens_sum
,
topk
,
topk
,
spec_steps
,
spec_steps
,
num_verify_token
,
num_verify_tokens
,
)
return
cls
(
draft_tokens
,
tree_mask
,
position
,
retrive_index
,
retrive_next_token
,
retrive_next_sibling
,
None
,
num_verify_tokens
,
spec_steps
,
CaptureHiddenMode
.
FULL
,
)
)
)
return
cls
(
draft_tokens
,
tree_mask
,
position
,
retrive_index
,
retrive_cum_len
,
num_verify_token
,
CaptureHiddenMode
.
FULL
,
)
def
prepare_for_verify
(
self
,
batch
:
ScheduleBatch
):
def
prepare_for_verify
(
self
,
batch
:
ScheduleBatch
):
batch
.
input_ids
=
self
.
draft_token
batch
.
input_ids
=
self
.
draft_token
...
@@ -313,12 +362,6 @@ class EagleVerifyInput:
...
@@ -313,12 +362,6 @@ class EagleVerifyInput:
uniform_samples
=
coins
,
uniform_samples
=
coins
,
target_probs
=
target_probs
,
target_probs
=
target_probs
,
draft_probs
=
draft_probs
,
draft_probs
=
draft_probs
,
threshold_single
=
global_server_args_dict
[
"speculative_accept_threshold_single"
],
threshold_acc
=
global_server_args_dict
[
"speculative_accept_threshold_acc"
],
deterministic
=
True
,
deterministic
=
True
,
)
)
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
fad315cb
...
@@ -185,6 +185,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -185,6 +185,7 @@ class EAGLEWorker(TpModelWorker):
self
.
topk
,
self
.
topk
,
self
.
speculative_num_steps
,
self
.
speculative_num_steps
,
self
.
server_args
.
speculative_num_draft_tokens
,
self
.
server_args
.
speculative_num_draft_tokens
,
batch
.
sampling_info
.
is_all_greedy
,
)
)
# Free cache locations
# Free cache locations
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment