Unverified commit 6ce0ed07, authored by Ke Bao and committed by GitHub
Browse files

Apply constraint grammar to EAGLE (#6499)


Co-authored-by: merrymercy <lianminzheng@gmail.com>
parent 969660c7
...@@ -9,15 +9,18 @@ import torch.nn.functional as F ...@@ -9,15 +9,18 @@ import torch.nn.functional as F
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ( from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch, ScheduleBatch,
get_last_loc, get_last_loc,
global_server_args_dict, global_server_args_dict,
) )
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2 from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
...@@ -187,6 +190,7 @@ class EagleVerifyInput: ...@@ -187,6 +190,7 @@ class EagleVerifyInput:
draft_token_num: int draft_token_num: int
spec_steps: int spec_steps: int
capture_hidden_mode: CaptureHiddenMode capture_hidden_mode: CaptureHiddenMode
grammar: BaseGrammarObject = None
@classmethod @classmethod
def create( def create(
...@@ -307,6 +311,7 @@ class EagleVerifyInput: ...@@ -307,6 +311,7 @@ class EagleVerifyInput:
logits_output: torch.Tensor, logits_output: torch.Tensor,
token_to_kv_pool_allocator: TokenToKVPoolAllocator, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int, page_size: int,
vocab_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Verify and find accepted tokens based on logits output and batch Verify and find accepted tokens based on logits output and batch
...@@ -343,6 +348,13 @@ class EagleVerifyInput: ...@@ -343,6 +348,13 @@ class EagleVerifyInput:
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
) )
# Apply grammar mask
if vocab_mask is not None:
assert self.grammar is not None
self.grammar.apply_vocab_mask(
logits=logits_output.next_token_logits, vocab_mask=vocab_mask
)
# Sample tokens # Sample tokens
if batch.sampling_info.is_all_greedy: if batch.sampling_info.is_all_greedy:
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1) target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
...@@ -440,6 +452,15 @@ class EagleVerifyInput: ...@@ -440,6 +452,15 @@ class EagleVerifyInput:
break break
else: else:
new_accept_index_.append(idx) new_accept_index_.append(idx)
# update grammar state
if req.grammar is not None:
try:
req.grammar.accept_token(id)
except ValueError as e:
logger.info(
f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
)
raise e
if not req.finished(): if not req.finished():
new_accept_index.extend(new_accept_index_) new_accept_index.extend(new_accept_index_)
unfinished_index.append(i) unfinished_index.append(i)
...@@ -801,3 +822,113 @@ def _generate_simulated_accept_index( ...@@ -801,3 +822,113 @@ def _generate_simulated_accept_index(
accept_length.fill_(simulate_acc_len - 1) accept_length.fill_(simulate_acc_len - 1)
predict.fill_(100) # some legit token id predict.fill_(100) # some legit token id
return sim_accept_index return sim_accept_index
def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """Walk the draft-token tree and compute a grammar bitmask per tree node.

    The draft model proposes a tree of candidate tokens. For every node we
    decide whether its token is permitted by the grammar (by testing the bit
    the parent wrote for it), and if so we advance the grammar state, write
    the allowed-token bitmask for that node into ``allocate_token_bitmask``,
    recurse into the child, and roll the grammar state back afterwards.

    Args:
        retrieve_next_token: per-node index of the first child (-1 if none).
        retrieve_next_sibling: per-node index of the next sibling (-1 if none).
        draft_tokens: per-node draft token id.
        grammar: grammar state machine; mutated during traversal but restored
            to its entry state before returning (every accept is rolled back).
        allocate_token_bitmask: output mask, one row per tree node; 32 boolean
            entries are packed into each 32-bit integer. Zeroed on entry.
    """
    assert (
        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
    )

    allocate_token_bitmask.fill_(0)

    def visit(node: int, parent_pos: int):
        # Node 0 is the token the target model emitted last iteration, so it
        # is accepted by construction; every other node must be permitted by
        # the bitmask its parent produced.
        if node == 0:
            allowed = True
        else:
            token_id = draft_tokens[node]
            parent_row = allocate_token_bitmask[parent_pos]
            # Test the packed bit for this token (32 flags per int32 word).
            allowed = (parent_row[token_id // 32] & (1 << (token_id % 32))) != 0

        if allowed:
            if node != 0:
                # Advance the grammar by this draft token.
                grammar.accept_token(draft_tokens[node])

            if not grammar.is_terminated():
                # Record which tokens may follow at this node.
                grammar.fill_vocab_mask(allocate_token_bitmask, node)

                child = retrieve_next_token[node]
                if child != -1:
                    # Descend into the subtree rooted at the first child.
                    visit(child, node)

            if node != 0:
                # Undo the accept so siblings see the parent's grammar state.
                grammar.rollback(1)

        sibling = retrieve_next_sibling[node]
        if sibling != -1:
            # Siblings share this node's parent position.
            visit(sibling, parent_pos)

    visit(0, -1)
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """Build the structured-output logits mask for a batch of requests.

    A draft token may or may not be valid under a request's grammar, so we
    run a DFS over each request's draft tree (``traverse_tree``) to determine
    which tokens the grammar accepts and what bitmask each tree position
    should apply to the target model's logits.

    Args:
        reqs: batch of requests; only those with a non-None ``grammar`` are
            processed.
        verify_input: receives the last grammar seen (``verify_input.grammar``)
            so the verify step can apply the mask later.
        retrieve_next_token_cpu / retrieve_next_sibling_cpu: per-request tree
            child / sibling index tensors, already on CPU.
        draft_tokens_cpu: per-request draft token ids, already on CPU.
        vocab_size: vocabulary size used to size the bitmask.

    Returns:
        The allocated bitmask covering every draft token in the batch, or
        None when no request in the batch has a grammar.
    """
    assert len(reqs) == retrieve_next_token_cpu.shape[0]

    tokens_per_req = draft_tokens_cpu.shape[-1]
    bitmask = None
    last_grammar = None

    for idx, req in enumerate(reqs):
        if req.grammar is None:
            continue

        last_grammar = req.grammar
        if bitmask is None:
            # Lazily allocate one mask row per draft token across the whole
            # batch; rows for grammar-less requests simply stay untouched.
            bitmask = req.grammar.allocate_vocab_mask(
                vocab_size=vocab_size,
                batch_size=draft_tokens_cpu.numel(),
                device="cpu",
            )

        row_begin = idx * tokens_per_req
        traverse_tree(
            retrieve_next_token_cpu[idx],
            retrieve_next_sibling_cpu[idx],
            draft_tokens_cpu[idx],
            req.grammar,
            bitmask[row_begin : row_begin + tokens_per_req],
        )

    verify_input.grammar = last_grammar
    return bitmask
...@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import ( ...@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
EagleVerifyInput, EagleVerifyInput,
EagleVerifyOutput, EagleVerifyOutput,
assign_draft_cache_locs, assign_draft_cache_locs,
generate_token_bitmask,
select_top_k_tokens, select_top_k_tokens,
) )
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
...@@ -492,11 +493,41 @@ class EAGLEWorker(TpModelWorker): ...@@ -492,11 +493,41 @@ class EAGLEWorker(TpModelWorker):
batch.forward_mode = ForwardMode.TARGET_VERIFY batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
if batch.has_grammar:
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
draft_tokens_cpu = spec_info.draft_token.view(
spec_info.retrive_next_token.shape
).cpu()
# Forward
logits_output, _, can_run_cuda_graph = ( logits_output, _, can_run_cuda_graph = (
self.target_worker.forward_batch_generation( self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True model_worker_batch, skip_sample=True
) )
) )
vocab_mask = None
if batch.has_grammar:
# Generate the logit mask for structured output.
# Overlap the CPU operations for bitmask generation with the forward pass.
vocab_mask = generate_token_bitmask(
batch.reqs,
spec_info,
retrieve_next_token_cpu,
retrieve_next_sibling_cpu,
draft_tokens_cpu,
batch.sampling_info.vocab_size,
)
if vocab_mask is not None:
assert spec_info.grammar is not None
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
# otherwise, this vocab mask will be the one from the previous extend stage
# and will be applied to produce wrong results
batch.sampling_info.vocab_mask = None
self._detect_nan_if_needed(logits_output) self._detect_nan_if_needed(logits_output)
spec_info.hidden_states = logits_output.hidden_states spec_info.hidden_states = logits_output.hidden_states
res: EagleVerifyOutput = spec_info.verify( res: EagleVerifyOutput = spec_info.verify(
...@@ -504,6 +535,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -504,6 +535,7 @@ class EAGLEWorker(TpModelWorker):
logits_output, logits_output,
self.token_to_kv_pool_allocator, self.token_to_kv_pool_allocator,
self.page_size, self.page_size,
vocab_mask,
) )
# Post process based on verified outputs. # Post process based on verified outputs.
......
...@@ -481,6 +481,41 @@ class TestEAGLEServer(CustomTestCase): ...@@ -481,6 +481,41 @@ class TestEAGLEServer(CustomTestCase):
with ThreadPoolExecutor(8) as executor: with ThreadPoolExecutor(8) as executor:
list(executor.map(self.run_decode, args)) list(executor.map(self.run_decode, args))
def test_constrained_decoding(self):
    """JSON-mode (constrained) decoding must return a parseable JSON object."""
    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Give me a json"},
    ]

    resp = requests.post(
        self.base_url + "/v1/chat/completions",
        json={
            "model": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            "messages": chat,
            "temperature": 0,
            "response_format": {"type": "json_object"},
        },
    )
    self.assertEqual(resp.status_code, 200)
    body = resp.json()

    # The completion must carry exactly one choice with message content.
    self.assertIn("choices", body)
    self.assertEqual(len(body["choices"]), 1)
    self.assertIn("message", body["choices"][0])
    self.assertIn("content", body["choices"][0]["message"])

    # The content itself must parse as a JSON object (dict), which is what
    # the grammar constraint is supposed to guarantee.
    raw = body["choices"][0]["message"]["content"]
    parsed_ok = True
    try:
        parsed = json.loads(raw)
        self.assertIsInstance(parsed, dict)
    except Exception:
        print(f"parse JSON failed: {raw}")
        parsed_ok = False
    self.assertTrue(parsed_ok)
class TestEAGLERetract(TestEAGLEServer): class TestEAGLERetract(TestEAGLEServer):
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment