Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6ce0ed07
Unverified
Commit
6ce0ed07
authored
May 22, 2025
by
Ke Bao
Committed by
GitHub
May 21, 2025
Browse files
Apply constraint grammar to EAGLE (#6499)
Co-authored-by:
merrymercy
<
lianminzheng@gmail.com
>
parent
969660c7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
198 additions
and
0 deletions
+198
-0
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+131
-0
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+32
-0
test/srt/test_eagle_infer.py
test/srt/test_eagle_infer.py
+35
-0
No files found.
python/sglang/srt/speculative/eagle_utils.py
View file @
6ce0ed07
...
@@ -9,15 +9,18 @@ import torch.nn.functional as F
...
@@ -9,15 +9,18 @@ import torch.nn.functional as F
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.managers.schedule_batch
import
(
from
sglang.srt.managers.schedule_batch
import
(
Req
,
ScheduleBatch
,
ScheduleBatch
,
get_last_loc
,
get_last_loc
,
global_server_args_dict
,
global_server_args_dict
,
)
)
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.utils
import
fast_topk
,
is_cuda
,
is_hip
,
next_power_of_2
from
sglang.srt.utils
import
fast_topk
,
is_cuda
,
is_hip
,
next_power_of_2
...
@@ -187,6 +190,7 @@ class EagleVerifyInput:
...
@@ -187,6 +190,7 @@ class EagleVerifyInput:
draft_token_num
:
int
draft_token_num
:
int
spec_steps
:
int
spec_steps
:
int
capture_hidden_mode
:
CaptureHiddenMode
capture_hidden_mode
:
CaptureHiddenMode
grammar
:
BaseGrammarObject
=
None
@
classmethod
@
classmethod
def
create
(
def
create
(
...
@@ -307,6 +311,7 @@ class EagleVerifyInput:
...
@@ -307,6 +311,7 @@ class EagleVerifyInput:
logits_output
:
torch
.
Tensor
,
logits_output
:
torch
.
Tensor
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
page_size
:
int
,
page_size
:
int
,
vocab_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Verify and find accepted tokens based on logits output and batch
Verify and find accepted tokens based on logits output and batch
...
@@ -343,6 +348,13 @@ class EagleVerifyInput:
...
@@ -343,6 +348,13 @@ class EagleVerifyInput:
torch
.
repeat_interleave
(
linear_penalty
,
self
.
draft_token_num
,
dim
=
0
)
torch
.
repeat_interleave
(
linear_penalty
,
self
.
draft_token_num
,
dim
=
0
)
)
)
# Apply grammar mask
if
vocab_mask
is
not
None
:
assert
self
.
grammar
is
not
None
self
.
grammar
.
apply_vocab_mask
(
logits
=
logits_output
.
next_token_logits
,
vocab_mask
=
vocab_mask
)
# Sample tokens
# Sample tokens
if
batch
.
sampling_info
.
is_all_greedy
:
if
batch
.
sampling_info
.
is_all_greedy
:
target_predict
=
torch
.
argmax
(
logits_output
.
next_token_logits
,
dim
=-
1
)
target_predict
=
torch
.
argmax
(
logits_output
.
next_token_logits
,
dim
=-
1
)
...
@@ -440,6 +452,15 @@ class EagleVerifyInput:
...
@@ -440,6 +452,15 @@ class EagleVerifyInput:
break
break
else
:
else
:
new_accept_index_
.
append
(
idx
)
new_accept_index_
.
append
(
idx
)
# update grammar state
if
req
.
grammar
is
not
None
:
try
:
req
.
grammar
.
accept_token
(
id
)
except
ValueError
as
e
:
logger
.
info
(
f
"
{
i
=
}
,
{
req
=
}
\n
"
f
"
{
accept_index
=
}
\n
"
f
"
{
predict
=
}
\n
"
)
raise
e
if
not
req
.
finished
():
if
not
req
.
finished
():
new_accept_index
.
extend
(
new_accept_index_
)
new_accept_index
.
extend
(
new_accept_index_
)
unfinished_index
.
append
(
i
)
unfinished_index
.
append
(
i
)
...
@@ -801,3 +822,113 @@ def _generate_simulated_accept_index(
...
@@ -801,3 +822,113 @@ def _generate_simulated_accept_index(
accept_length
.
fill_
(
simulate_acc_len
-
1
)
accept_length
.
fill_
(
simulate_acc_len
-
1
)
predict
.
fill_
(
100
)
# some legit token id
predict
.
fill_
(
100
)
# some legit token id
return
sim_accept_index
return
sim_accept_index
def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """
    Walk the draft tree depth-first and fill in the vocab bitmask row of every
    node whose token the grammar accepts.

    The tree is encoded by two link arrays indexed by node position:
    ``retrieve_next_token[n]`` is the first child of ``n`` and
    ``retrieve_next_sibling[n]`` is the next sibling of ``n``; ``-1`` means
    "none". ``draft_tokens[n]`` is the token id stored at node ``n``.

    ``allocate_token_bitmask`` is zeroed first and then written in place via
    ``grammar.fill_vocab_mask``; nothing is returned. Every ``accept_token``
    issued here is paired with a ``rollback(1)`` once the node's subtree has
    been visited.
    """
    assert retrieve_next_token.shape == retrieve_next_sibling.shape
    assert retrieve_next_sibling.shape == draft_tokens.shape

    # Rows of nodes that are never reached/accepted stay all-zero.
    allocate_token_bitmask.fill_(0)

    def allowed_by_parent(node, parent):
        # 32 boolean bitmask values are packed into each 32-bit integer, so
        # the token's bit lives in word (id // 32) at bit position (id % 32).
        token_id = draft_tokens[node]
        word = allocate_token_bitmask[parent][token_id // 32]
        return (word & (1 << (token_id % 32))) != 0

    def visit(node, parent):
        # Siblings share the same parent and are handled by this loop;
        # only the descent into a child recurses.
        while True:
            # Node 0 is the token produced by the target model, so it is
            # treated as already accepted: no mask check, accept or rollback.
            at_root = bool(node == 0)

            if at_root or allowed_by_parent(node, parent):
                if not at_root:
                    # Advance the grammar by this node's token.
                    grammar.accept_token(draft_tokens[node])

                if not grammar.is_terminated():
                    # Record which tokens may follow this node.
                    grammar.fill_vocab_mask(allocate_token_bitmask, node)
                    child = retrieve_next_token[node]
                    if child != -1:
                        visit(child, node)

                if not at_root:
                    # Undo this node's token before moving to its sibling.
                    grammar.rollback(1)

            sibling = retrieve_next_sibling[node]
            if sibling == -1:
                return
            node = sibling

    visit(0, -1)
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """
    Build the logit mask used for structured output during verification.

    A draft token may or may not be valid under a request's grammar, so each
    grammar-constrained request has its draft tree walked with
    ``traverse_tree`` to determine:
    1. which draft tokens the grammar accepts, and
    2. the logit mask that applies after each accepted token.

    The mask is allocated lazily, by the first request that has a grammar,
    with one row per element of ``draft_tokens_cpu``. Request ``i`` owns the
    rows ``[i * n, (i + 1) * n)`` where ``n`` is the size of the last
    dimension of ``draft_tokens_cpu``.

    Side effect: ``verify_input.grammar`` is set to the grammar of the last
    request that has one, or to ``None`` when no request has a grammar.

    Returns the allocated mask, or ``None`` when no request has a grammar.
    """
    tokens_per_req = draft_tokens_cpu.shape[-1]
    assert len(reqs) == retrieve_next_token_cpu.shape[0]

    bitmask = None
    last_grammar = None

    for req_idx, req in enumerate(reqs):
        req_grammar = req.grammar
        if req_grammar is None:
            # Unconstrained request: its rows are left as allocated.
            continue

        if bitmask is None:
            bitmask = req_grammar.allocate_vocab_mask(
                vocab_size=vocab_size,
                batch_size=draft_tokens_cpu.numel(),
                device="cpu",
            )
        last_grammar = req_grammar

        row_start = req_idx * tokens_per_req
        traverse_tree(
            retrieve_next_token_cpu[req_idx],
            retrieve_next_sibling_cpu[req_idx],
            draft_tokens_cpu[req_idx],
            req_grammar,
            bitmask[row_start : row_start + tokens_per_req],
        )

    verify_input.grammar = last_grammar
    return bitmask
python/sglang/srt/speculative/eagle_worker.py
View file @
6ce0ed07
...
@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
...
@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
EagleVerifyInput
,
EagleVerifyInput
,
EagleVerifyOutput
,
EagleVerifyOutput
,
assign_draft_cache_locs
,
assign_draft_cache_locs
,
generate_token_bitmask
,
select_top_k_tokens
,
select_top_k_tokens
,
)
)
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
...
@@ -492,11 +493,41 @@ class EAGLEWorker(TpModelWorker):
...
@@ -492,11 +493,41 @@ class EAGLEWorker(TpModelWorker):
batch
.
forward_mode
=
ForwardMode
.
TARGET_VERIFY
batch
.
forward_mode
=
ForwardMode
.
TARGET_VERIFY
batch
.
spec_info
=
spec_info
batch
.
spec_info
=
spec_info
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
if
batch
.
has_grammar
:
retrieve_next_token_cpu
=
spec_info
.
retrive_next_token
.
cpu
()
retrieve_next_sibling_cpu
=
spec_info
.
retrive_next_sibling
.
cpu
()
draft_tokens_cpu
=
spec_info
.
draft_token
.
view
(
spec_info
.
retrive_next_token
.
shape
).
cpu
()
# Forward
logits_output
,
_
,
can_run_cuda_graph
=
(
logits_output
,
_
,
can_run_cuda_graph
=
(
self
.
target_worker
.
forward_batch_generation
(
self
.
target_worker
.
forward_batch_generation
(
model_worker_batch
,
skip_sample
=
True
model_worker_batch
,
skip_sample
=
True
)
)
)
)
vocab_mask
=
None
if
batch
.
has_grammar
:
# Generate the logit mask for structured output.
# Overlap the CPU operations for bitmask generation with the forward pass.
vocab_mask
=
generate_token_bitmask
(
batch
.
reqs
,
spec_info
,
retrieve_next_token_cpu
,
retrieve_next_sibling_cpu
,
draft_tokens_cpu
,
batch
.
sampling_info
.
vocab_size
,
)
if
vocab_mask
is
not
None
:
assert
spec_info
.
grammar
is
not
None
vocab_mask
=
vocab_mask
.
to
(
spec_info
.
retrive_next_token
.
device
)
# otherwise, this vocab mask will be the one from the previous extend stage
# and will be applied to produce wrong results
batch
.
sampling_info
.
vocab_mask
=
None
self
.
_detect_nan_if_needed
(
logits_output
)
self
.
_detect_nan_if_needed
(
logits_output
)
spec_info
.
hidden_states
=
logits_output
.
hidden_states
spec_info
.
hidden_states
=
logits_output
.
hidden_states
res
:
EagleVerifyOutput
=
spec_info
.
verify
(
res
:
EagleVerifyOutput
=
spec_info
.
verify
(
...
@@ -504,6 +535,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -504,6 +535,7 @@ class EAGLEWorker(TpModelWorker):
logits_output
,
logits_output
,
self
.
token_to_kv_pool_allocator
,
self
.
token_to_kv_pool_allocator
,
self
.
page_size
,
self
.
page_size
,
vocab_mask
,
)
)
# Post process based on verified outputs.
# Post process based on verified outputs.
...
...
test/srt/test_eagle_infer.py
View file @
6ce0ed07
...
@@ -481,6 +481,41 @@ class TestEAGLEServer(CustomTestCase):
...
@@ -481,6 +481,41 @@ class TestEAGLEServer(CustomTestCase):
with
ThreadPoolExecutor
(
8
)
as
executor
:
with
ThreadPoolExecutor
(
8
)
as
executor
:
list
(
executor
.
map
(
self
.
run_decode
,
args
))
list
(
executor
.
map
(
self
.
run_decode
,
args
))
def test_constrained_decoding(self):
    """Check that a ``json_object`` response format yields a JSON object.

    Sends one chat completion request with ``temperature=0`` and
    ``response_format={"type": "json_object"}`` to the server at
    ``self.base_url``, then verifies the response structure and that the
    message content parses as a JSON object (a ``dict``).
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Give me a json"},
    ]

    response = requests.post(
        self.base_url + "/v1/chat/completions",
        json={
            "model": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            "messages": messages,
            "temperature": 0,
            "response_format": {"type": "json_object"},
        },
    )
    self.assertEqual(response.status_code, 200)
    res = response.json()

    # Validate response structure
    self.assertIn("choices", res)
    self.assertEqual(len(res["choices"]), 1)
    self.assertIn("message", res["choices"][0])
    self.assertIn("content", res["choices"][0]["message"])

    # Validate JSON content
    content_json = res["choices"][0]["message"]["content"]
    try:
        content = json.loads(content_json)
    except (TypeError, ValueError) as e:
        # Only parsing problems are handled here: ValueError covers
        # json.JSONDecodeError, TypeError covers non-string content.
        print(f"parse JSON failed: {content_json}")
        self.fail(f"response content is not valid JSON: {e}")

    # Kept outside the try block so a wrong top-level type is reported as
    # such instead of being mislabelled as a JSON parse failure.
    self.assertIsInstance(content, dict)
class
TestEAGLERetract
(
TestEAGLEServer
):
class
TestEAGLERetract
(
TestEAGLEServer
):
@
classmethod
@
classmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment