Unverified Commit e493e485 authored by Madeesh Kannan's avatar Madeesh Kannan Committed by GitHub
Browse files

[V0][Bugfix] Fix parallel sampling performance regression when guided decoding is enabled (#17731)


Signed-off-by: default avatarMadeesh Kannan <shadeMe@users.noreply.github.com>
Co-authored-by: default avatarRussell Bryant <rbryant@redhat.com>
parent 4ce64e2d
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import copy
import os import os
from typing import Any from typing import Any
...@@ -34,9 +35,24 @@ class GuidanceLogitsProcessor: ...@@ -34,9 +35,24 @@ class GuidanceLogitsProcessor:
self.grammar = grammar self.grammar = grammar
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.tokenizer_name = tokenizer.name_or_path self.tokenizer_name = tokenizer.name_or_path
self.ll_tokenizer = None
self.ll_matcher = None
self.bitmask = None
self.new_sampling = False self.new_sampling = False
self.initialized = False self.initialized = False
def clone(self) -> "GuidanceLogitsProcessor":
cloned = copy.copy(self)
if self.initialized:
cloned.ll_matcher = llguidance.LLMatcher(
self.ll_tokenizer, # type: ignore[assignment]
self.grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined]
return cloned
def _initialize(self): def _initialize(self):
if self.initialized: if self.initialized:
return return
...@@ -56,7 +72,7 @@ class GuidanceLogitsProcessor: ...@@ -56,7 +72,7 @@ class GuidanceLogitsProcessor:
# create reusable bitmask # create reusable bitmask
self.bitmask = llguidance.torch.allocate_token_bitmask( self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size) 1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined]
self.initialized = True self.initialized = True
...@@ -70,15 +86,17 @@ class GuidanceLogitsProcessor: ...@@ -70,15 +86,17 @@ class GuidanceLogitsProcessor:
self._initialize() self._initialize()
if self.new_sampling and len(input_ids) > 0: if self.new_sampling and len(input_ids) > 0:
self.ll_matcher.consume_token(input_ids[-1]) self.ll_matcher.consume_token( # type: ignore[attr-defined]
err = self.ll_matcher.get_error() input_ids[-1])
err = self.ll_matcher.get_error() # type: ignore[attr-defined]
if err: if err:
logger.warning("Error in LLMatcher: %s", err) logger.warning("Error in LLMatcher: %s", err)
llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask, llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
0) 0)
llguidance.torch.apply_token_bitmask_inplace( llguidance.torch.apply_token_bitmask_inplace(
scores, self.bitmask.to(scores.device)) scores,
self.bitmask.to(scores.device)) # type: ignore[attr-defined]
self.new_sampling = True self.new_sampling = True
......
...@@ -56,6 +56,12 @@ class BaseLogitsProcessor: ...@@ -56,6 +56,12 @@ class BaseLogitsProcessor:
self._fsm_state: defaultdict[int, Union[int, self._fsm_state: defaultdict[int, Union[int,
CFGState]] = defaultdict(int) CFGState]] = defaultdict(int)
def clone(self) -> "BaseLogitsProcessor":
cloned = copy.copy(self)
cloned._guide = self._guide.copy()
cloned._fsm_state = copy.deepcopy(self._fsm_state)
return cloned
def __call__(self, input_ids: list[int], def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor: scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token.""" """Use the FSM to bias the logits before sampling the next token."""
...@@ -218,6 +224,12 @@ class CFGLogitsProcessor(BaseLogitsProcessor): ...@@ -218,6 +224,12 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
reasoner) reasoner)
self._guide = self._guide.copy() self._guide = self._guide.copy()
def clone(self) -> "CFGLogitsProcessor":
cloned = copy.copy(self)
cloned._fsm_state = copy.deepcopy(self._fsm_state)
cloned._guide = self._guide.copy()
return cloned
@lru_cache(maxsize=32) @lru_cache(maxsize=32)
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
......
...@@ -302,6 +302,7 @@ class XGrammarLogitsProcessor: ...@@ -302,6 +302,7 @@ class XGrammarLogitsProcessor:
prefilled: bool = field(default=False) prefilled: bool = field(default=False)
def __post_init__(self): def __post_init__(self):
if self.tokenizer_info is None:
self.tokenizer_info = self.config.tokenizer_info( self.tokenizer_info = self.config.tokenizer_info(
self.config.tokenizer_data) self.config.tokenizer_data)
...@@ -400,7 +401,8 @@ class XGrammarLogitsProcessor: ...@@ -400,7 +401,8 @@ class XGrammarLogitsProcessor:
def clone(self) -> XGrammarLogitsProcessor: def clone(self) -> XGrammarLogitsProcessor:
"""Create a new instance with shared compiled grammar """Create a new instance with shared compiled grammar
but separate state""" but separate state"""
new_processor = XGrammarLogitsProcessor(self.config, self.reasoner) new_processor = XGrammarLogitsProcessor(self.config, self.reasoner,
None, self.tokenizer_info)
# Share the compiled grammar context (immutable after compilation) # Share the compiled grammar context (immutable after compilation)
new_processor.ctx = self.ctx new_processor.ctx = self.ctx
......
...@@ -1494,7 +1494,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase): ...@@ -1494,7 +1494,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
for i in range(original_params.n): for i in range(original_params.n):
request_id_i = f"{request_id}_parallel_sample_{i}" request_id_i = f"{request_id}_parallel_sample_{i}"
group.seq_id_to_index[request_id_i] = i group.seq_id_to_index[request_id_i] = i
params = copy.deepcopy(original_params) params = params.clone()
params.n = 1 params.n = 1
if params.seed is not None: if params.seed is not None:
params.seed += i params.seed += i
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment