Unverified Commit 762dbf3f authored by drbh's avatar drbh Committed by GitHub
Browse files

fix: handle batches with and without grammars (#1676)

This PR correctly handles batches with a mixture of constrained and non
constrained generations.

Currently if batch contains mixed generations the generation will throw
an error because it will incorrectly attempt to constrain a request with
an empty grammar.

We now handled `None` grammars and only apply the mask if needed

Fixes:
https://github.com/huggingface/text-generation-inference/issues/1643
parent 818aee37
...@@ -555,6 +555,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): ...@@ -555,6 +555,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = [] self.fsms = []
for grammar, grammar_type in zip(grammars, grammar_types): for grammar, grammar_type in zip(grammars, grammar_types):
if len(grammar) == 0:
self.fsms.append(None)
continue
fsm = GrammarLogitProcessor._cached_compile_fsm( fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type, grammar, self.tokenizer grammar_type, grammar, self.tokenizer
) )
...@@ -572,7 +575,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): ...@@ -572,7 +575,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
continue continue
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
mask[i, allowed_tokens] = 0 mask[i, allowed_tokens] = 0
logits += mask logits[i] += mask[i]
return logits return logits
def advance_batch(self, next_token_ids, fsm_grammar_states): def advance_batch(self, next_token_ids, fsm_grammar_states):
...@@ -584,6 +587,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): ...@@ -584,6 +587,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
] ]
def advance_at_index(self, next_token_id, fsm_grammar_state, index): def advance_at_index(self, next_token_id, fsm_grammar_state, index):
if self.fsms[index] is None:
return fsm_grammar_state
return GrammarLogitProcessor._advance( return GrammarLogitProcessor._advance(
next_token_id, fsm_grammar_state, self.fsms[index] next_token_id, fsm_grammar_state, self.fsms[index]
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment