Unverified Commit 343aa7a1 authored by drbh's avatar drbh Committed by GitHub
Browse files

fix: Handle concurrent grammar requests (#1610)

This PR fixes parallel grammar requests. Currently, grammar states are
not concatenated correctly when a new request is added to the batch,
and this results in incorrect generation. This PR updates the `concatenate`
function to correctly include the previous states.

fixes: #1601
parent e6bb3ff8
......@@ -61,7 +61,7 @@
},
{
"id": 29906,
"logprob": -0.2376709,
"logprob": -0.33666992,
"special": false,
"text": "2"
},
......@@ -180,7 +180,7 @@
},
{
"id": 29906,
"logprob": -0.23840332,
"logprob": -0.33740234,
"special": false,
"text": "2"
},
......@@ -299,7 +299,7 @@
},
{
"id": 29906,
"logprob": -0.23840332,
"logprob": -0.33740234,
"special": false,
"text": "2"
},
......@@ -418,7 +418,7 @@
},
{
"id": 29906,
"logprob": -0.23840332,
"logprob": -0.33740234,
"special": false,
"text": "2"
},
......
......@@ -530,6 +530,7 @@ class FlashCausalLMBatch(Batch):
read_offsets = []
next_token_chooser_parameters = []
fsm_grammar_states = []
stopping_criterias = []
top_n_tokens = []
......@@ -578,6 +579,7 @@ class FlashCausalLMBatch(Batch):
read_offsets.extend(batch.read_offsets)
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
......@@ -593,6 +595,7 @@ class FlashCausalLMBatch(Batch):
dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device,
tokenizer=batches[0].next_token_chooser.tokenizer,
fsm_grammar_states=fsm_grammar_states,
)
speculative_ids = (
......
......@@ -466,6 +466,7 @@ class HeterogeneousNextTokenChooser:
dtype: torch.dtype,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
fsm_grammar_states: Optional[List[int]] = None,
) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb],
......@@ -482,7 +483,9 @@ class HeterogeneousNextTokenChooser:
tokenizer=tokenizer,
grammars=[pb_.grammar for pb_ in pb],
grammar_types=[pb_.grammar_type for pb_ in pb],
fsm_grammar_states=[0] * len(pb),
fsm_grammar_states=(
fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment