Unverified Commit 6d6ea5af authored by tazjin's avatar tazjin Committed by GitHub
Browse files

fix: do not wrap invalid grammar objects during constrained generation (#11328)

parent 1dacedd2
...@@ -17,7 +17,11 @@ from typing import List, Optional, Tuple ...@@ -17,7 +17,11 @@ from typing import List, Optional, Tuple
import torch import torch
from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject from .base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
BaseGrammarBackend,
BaseGrammarObject,
)
class ReasonerGrammarObject(BaseGrammarObject): class ReasonerGrammarObject(BaseGrammarObject):
...@@ -81,10 +85,9 @@ class ReasonerGrammarBackend(BaseGrammarBackend): ...@@ -81,10 +85,9 @@ class ReasonerGrammarBackend(BaseGrammarBackend):
self.grammar_backend = grammar_backend self.grammar_backend = grammar_backend
self.think_end_id = think_end_id self.think_end_id = think_end_id
def _init_value_dispatch( def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
self, key: Tuple[str, str]
) -> Optional[ReasonerGrammarObject]:
ret = self.grammar_backend._init_value_dispatch(key) ret = self.grammar_backend._init_value_dispatch(key)
if ret is None: # avoid wrapping invalid grammar, so that the scheduler can detect it
return None if ret is None or ret is INVALID_GRAMMAR_OBJ:
return ret
return ReasonerGrammarObject(ret, self.think_end_id) return ReasonerGrammarObject(ret, self.think_end_id)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment