Unverified Commit 6d6ea5af authored by tazjin's avatar tazjin Committed by GitHub
Browse files

fix: do not wrap invalid grammar objects during constrained generation (#11328)

parent 1dacedd2
......@@ -17,7 +17,11 @@ from typing import List, Optional, Tuple
import torch
from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
from .base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
BaseGrammarBackend,
BaseGrammarObject,
)
class ReasonerGrammarObject(BaseGrammarObject):
......@@ -81,10 +85,9 @@ class ReasonerGrammarBackend(BaseGrammarBackend):
self.grammar_backend = grammar_backend
self.think_end_id = think_end_id
def _init_value_dispatch(
self, key: Tuple[str, str]
) -> Optional[ReasonerGrammarObject]:
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
ret = self.grammar_backend._init_value_dispatch(key)
if ret is None:
return None
# avoid wrapping invalid grammar, so that the scheduler can detect it
if ret is None or ret is INVALID_GRAMMAR_OBJ:
return ret
return ReasonerGrammarObject(ret, self.think_end_id)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment