Unverified commit 9c745d07 authored by DarkSharpness, committed by GitHub

[Performance] Update xgrammar-related constrained decoding (#2056)

parent ebaa2f31
@@ -81,10 +81,20 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state
 
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
         vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
 
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
+
     def copy(self):
         return OutlinesGrammar(self.guide, self.jump_forward_map)
...
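
The Outlines path keeps a dense boolean mask, but fill_vocab_mask now writes into one row of a batch-wide tensor (selected by idx) instead of receiving a pre-sliced row, and the -inf masking moves into a static apply_vocab_mask. A minimal sketch of the resulting contract, with hypothetical sizes and allowed-token ids standing in for guide.get_next_instruction(state).tokens:

import torch

# Sketch only: made-up batch size, vocab size, and token ids, not sglang code.
batch_size, vocab_size, idx = 4, 32, 0

# allocate_vocab_mask: one boolean row per request (True = token is banned)
vocab_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)

# fill_vocab_mask(vocab_mask, idx): ban everything in row idx, then re-allow
# the tokens the grammar currently permits (hypothetical ids below)
allowed_tokens = [3, 7, 19]
row = vocab_mask[idx]
row.fill_(True)
row[allowed_tokens] = False

# apply_vocab_mask: banned positions are pushed to -inf before sampling
logits = torch.randn(batch_size, vocab_size)
logits.masked_fill_(vocab_mask, float("-inf"))
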
@@ -21,7 +21,12 @@ from typing import List, Tuple
 import torch
 
 try:
-    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
+    from xgrammar import (
+        CachedGrammarCompiler,
+        CompiledGrammar,
+        GrammarMatcher,
+        TokenizerInfo,
+    )
 
     import_error = None
 except ImportError as e:
@@ -80,19 +85,23 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])
 
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
-        # Note that this bitmask is a bitset, not bool
-        bitmask = self.matcher.get_next_token_bitmask()
-        # Mask the tokens that are not allowed
-        vocab_mask[
-            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
-        ] = 1
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
         matcher = GrammarMatcher(
             self.ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
@@ -112,7 +121,8 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
             self.grammar_cache = None
             return
 
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
+        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
+        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
@@ -122,9 +132,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
-                    key_string
-                )
+                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -141,7 +149,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         matcher = GrammarMatcher(
             ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, ctx)
...
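
On the xgrammar side, the backend no longer expands the bitmask into a list of rejected tokens in Python; the matcher allocates a packed bitset, fills one row per request, and applies it to the logits in place. A rough sketch of the call order as wired up in this diff; the tokenizer name, JSON schema, and MAX_ROLLBACK_TOKENS value are placeholders, and the exact xgrammar signatures depend on the installed version, so treat this as an illustration rather than a reference:

import torch
from transformers import AutoTokenizer
from xgrammar import CachedGrammarCompiler, GrammarMatcher, TokenizerInfo

MAX_ROLLBACK_TOKENS = 10  # placeholder; sglang defines its own constant
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer
vocab_size = len(tokenizer)

# Backend init: build the compiler cache from tokenizer info (new in this PR).
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)

# Per request: compile the JSON schema once, then create a matcher for it.
ctx = grammar_cache.compile_json_schema_grammar(schema='{"type": "object"}')
matcher = GrammarMatcher(
    ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, vocab_size=vocab_size
)

# Per decoding step: the mask is a packed bitset filled by the matcher itself,
# and GrammarMatcher writes the -inf entries straight onto the logits.
batch_size, idx = 1, 0
vocab_mask = matcher.allocate_token_bitmask(vocab_size, batch_size)
matcher.fill_next_token_bitmask(vocab_mask, idx)

logits = torch.randn(batch_size, vocab_size)
GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
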
@@ -645,7 +645,7 @@ class ModelRunner:
 
         # Apply regex vocab_mask
         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
 
         return logits
...
 from __future__ import annotations
 
 import dataclasses
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
@@ -29,7 +29,7 @@ class SamplingBatchInfo:
     vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
+    apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
     grammars: Optional[List] = None
 
     # Penalizer
@@ -135,17 +135,23 @@ class SamplingBatchInfo:
     def update_regex_vocab_mask(self):
         if not self.grammars or not any(grammar for grammar in self.grammars):
             self.vocab_mask = None
+            self.apply_mask = None
             return
 
-        self.vocab_mask = torch.zeros(
-            len(self.temperatures),
-            self.vocab_size,
-            dtype=torch.bool,
+        # find a grammar from the list
+        grammar = next(grammar for grammar in self.grammars if grammar is not None)
+
+        # maybe we can reuse the existing mask?
+        self.vocab_mask = grammar.allocate_vocab_mask(
+            vocab_size=self.vocab_size,
+            batch_size=len(self.temperatures),
             device=self.device,
         )
+        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+
         for i, grammar in enumerate(self.grammars):
             if grammar is not None:
-                grammar.fill_vocab_mask(self.vocab_mask[i])
+                grammar.fill_vocab_mask(self.vocab_mask, i)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
...
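
Taken together, update_regex_vocab_mask now delegates the mask layout entirely to whichever grammar backend is active: it picks any non-None grammar to allocate one batch-wide mask, lets each grammar fill its own row, and records the backend's static apply_vocab_mask so the ModelRunner change above can apply it without knowing whether the mask is a bool tensor or a packed bitset. A self-contained illustration of that dispatch with a stub grammar class implementing the same three-method contract (not sglang code):

from typing import Callable, List, Optional

import torch

class StubGrammar:
    """Stand-in for OutlinesGrammar / XGrammarGrammar using a dense bool mask."""

    def __init__(self, allowed: List[int]):
        self.allowed = allowed

    def allocate_vocab_mask(self, vocab_size: int, batch_size: int, device) -> torch.Tensor:
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        row = vocab_mask[idx]
        row.fill_(True)
        row[self.allowed] = False

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        logits.masked_fill_(vocab_mask, float("-inf"))


# One grammar-constrained request and one unconstrained request in the batch.
grammars: List[Optional[StubGrammar]] = [StubGrammar(allowed=[1, 2]), None]
vocab_size, device = 8, "cpu"

# What update_regex_vocab_mask does: allocate once, fill per row, remember apply.
grammar = next(g for g in grammars if g is not None)
vocab_mask = grammar.allocate_vocab_mask(
    vocab_size=vocab_size, batch_size=len(grammars), device=device
)
apply_mask: Callable[..., None] = type(grammar).apply_vocab_mask
for i, g in enumerate(grammars):
    if g is not None:
        g.fill_vocab_mask(vocab_mask, i)

# What the ModelRunner does later in forward: backend-agnostic masking.
logits = torch.randn(len(grammars), vocab_size, device=device)
apply_mask(logits=logits, vocab_mask=vocab_mask)

The unconstrained request's row stays all False, so its logits are untouched, which mirrors how requests without a grammar pass through the real batch unchanged.
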