Commit 3a1e6481 (unverified), authored by Russell Bryant, committed by GitHub
Browse files

[V1] Refactor Structured Output for multiple backends (#14694)


Signed-off-by: Russell Bryant <rbryant@redhat.com>
parent 46c759c1
...@@ -119,16 +119,21 @@ class Processor: ...@@ -119,16 +119,21 @@ class Processor:
def _validate_structured_output(self, params: SamplingParams) -> None: def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config: if not params.guided_decoding or not self.decoding_config:
return return
if self.decoding_config.guided_decoding_backend != "xgrammar":
raise ValueError( supported_backends = ["xgrammar"]
"Only xgrammar structured output is supported in V1.") engine_level_backend = self.decoding_config.guided_decoding_backend
if (params.guided_decoding.backend if engine_level_backend not in supported_backends:
and params.guided_decoding.backend != 'xgrammar'): raise ValueError(f"Only {supported_backends} structured output is "
raise ValueError( "supported in V1.")
"Only xgrammar structured output is supported in V1.") if params.guided_decoding.backend:
if self.vllm_config.speculative_config: if params.guided_decoding.backend != engine_level_backend:
raise ValueError("Structured output is not supported with " raise ValueError("Request-level structured output backend "
"speculative decoding.") "must match engine-level backend. "
f"{params.guided_decoding.backend}"
f" != {engine_level_backend}")
else:
params.guided_decoding.backend = engine_level_backend
if vllm.platforms.current_platform.is_tpu(): if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.") raise ValueError("Structured output is not supported on TPU.")
......
...@@ -7,75 +7,27 @@ from typing import TYPE_CHECKING, Optional ...@@ -7,75 +7,27 @@ from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer StructuredOutputGrammar)
from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
if TYPE_CHECKING: if TYPE_CHECKING:
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import xgrammar as xgr import torch
from vllm.v1.request import Request from vllm.v1.request import Request
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__) logger = init_logger(__name__)
class StructuredOutputManager: class StructuredOutputManager:
"""Engine-level manager for structured output requests."""
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
self.backend: Optional[StructuredOutputBackend] = None
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.init_complete = False self._grammar_bitmask: Optional[torch.Tensor] = None
def _delayed_init(self):
"""Initialization delayed until we know it is needed."""
tokenizer_group = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
parallel_config=self.vllm_config.parallel_config,
lora_config=self.vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()
tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.vocab_size = self.vllm_config.model_config.get_vocab_size()
if isinstance(tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try:
encoded_vocab = [
token for token, _ in sorted(
tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if hasattr(
tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e
tokenizer_info = xgr.TokenizerInfo(
encoded_vocab=encoded_vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.BYTE_FALLBACK,
vocab_size=self.vocab_size,
stop_token_ids=stop_token_ids,
add_prefix_space=True,
)
else:
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer,
vocab_size=self.vocab_size,
)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
# The default max_workers if not specified is the number of CPUs * 5, # The default max_workers if not specified is the number of CPUs * 5,
# which is way too high since these tasks are CPU-bound, not I/O bound. # which is way too high since these tasks are CPU-bound, not I/O bound.
...@@ -83,28 +35,30 @@ class StructuredOutputManager: ...@@ -83,28 +35,30 @@ class StructuredOutputManager:
# compilation, so we set it to half the number of CPUs. # compilation, so we set it to half the number of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers) self.executor = ThreadPoolExecutor(max_workers=max_workers)
self._grammar_bitmask = xgr.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs,
self.vocab_size,
)
self.init_complete = True
def grammar_init(self, request: Request) -> None: def grammar_init(self, request: Request) -> None:
if request.structured_output_request is None: if request.structured_output_request is None:
return return
# The first time this is called, we need to finish initialization # Initialize the backend the first time it is needed.
# of xgrammar. We defer it to avoid the import of xgrammar and #
# initialization cost if it is not going to be used. # NOTE: We only support a single backend. We do NOT support different
if not self.init_complete: # backends on a per-request basis in V1 (for now, anyway...).
self._delayed_init() if self.backend is None:
backend_name = request.sampling_params.guided_decoding.backend_name
if backend_name == "xgrammar":
self.backend = XgrammarBackend(self.vllm_config)
else:
raise ValueError(
f"Unsupported structured output backend: {backend_name}")
grammar: Future[Grammar] = self.executor.submit( grammar: Future[StructuredOutputGrammar] = self.executor.submit(
self._async_create_grammar, request) self._async_create_grammar, request, self.backend)
request.structured_output_request.grammar = grammar # type: ignore[assignment] request.structured_output_request.grammar = grammar # type: ignore[assignment]
def _async_create_grammar(self, request: Request) -> Grammar: def _async_create_grammar(
self, request: Request,
backend: StructuredOutputBackend) -> StructuredOutputGrammar:
key = request.structured_output_request.structured_output_key # type: ignore[union-attr] key = request.structured_output_request.structured_output_key # type: ignore[union-attr]
# Note that the request was validated in the engine core client, # Note that the request was validated in the engine core client,
...@@ -114,28 +68,8 @@ class StructuredOutputManager: ...@@ -114,28 +68,8 @@ class StructuredOutputManager:
# though it should be unlikely as we test that up front as well. # though it should be unlikely as we test that up front as well.
request_type, grammar_spec = key request_type, grammar_spec = key
if request_type == StructuredOutputOptions.JSON: assert self.backend is not None
# TODO -- allow any_whitespace to be configurable return self.backend.compile_grammar(request_type, grammar_spec)
# pending merge of https://github.com/vllm-project/vllm/pull/12744
ctx = self.compiler.compile_json_schema(grammar_spec,
any_whitespace=False)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
ctx = self.compiler.compile_builtin_json_grammar()
elif request_type == StructuredOutputOptions.GRAMMAR:
ctx = self.compiler.compile_grammar(grammar_spec)
elif request_type == StructuredOutputOptions.REGEX:
ctx = self.compiler.compile_regex(grammar_spec)
else:
logger.error("Validation should have already occurred. "
"Please file an issue.")
raise ValueError(
f"grammar is not of valid supported types. ({request_type!s})")
return Grammar(
matcher=xgr.GrammarMatcher(ctx),
vocab_size=self.vocab_size,
ctx=ctx,
)
def grammar_bitmask( def grammar_bitmask(
self, self,
...@@ -147,6 +81,11 @@ class StructuredOutputManager: ...@@ -147,6 +81,11 @@ class StructuredOutputManager:
if not structured_output_request_ids: if not structured_output_request_ids:
return None return None
if self._grammar_bitmask is None:
assert self.backend is not None
self._grammar_bitmask = self.backend.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs)
# Fill the bitmask using the index of each request equal to its # Fill the bitmask using the index of each request equal to its
# position in the batch. Resize the bitmask down to the size of # position in the batch. Resize the bitmask down to the size of
# the batch. # the batch.
...@@ -154,7 +93,7 @@ class StructuredOutputManager: ...@@ -154,7 +93,7 @@ class StructuredOutputManager:
for req_id, batch_index in structured_output_request_ids.items(): for req_id, batch_index in structured_output_request_ids.items():
request = requests[req_id].structured_output_request request = requests[req_id].structured_output_request
assert request is not None and request.grammar is not None assert request is not None and request.grammar is not None
if not request.grammar.matcher.is_terminated(): if not request.grammar.is_terminated():
request.grammar.fill_bitmask(bitmask_tensor, batch_index) request.grammar.fill_bitmask(bitmask_tensor, batch_index)
if batch_len < self._grammar_bitmask.shape[0]: if batch_len < self._grammar_bitmask.shape[0]:
bitmask_tensor = self._grammar_bitmask[:batch_len] bitmask_tensor = self._grammar_bitmask[:batch_len]
......
# SPDX-License-Identifier: Apache-2.0
import enum
from abc import ABC, abstractmethod
import torch
@enum.unique
class StructuredOutputOptions(enum.Enum):
    """Kinds of structured output request a backend can be asked to compile.

    Member values are assigned automatically in declaration order, so the
    order of the members below must not change.
    """

    JSON = enum.auto()
    JSON_OBJECT = enum.auto()
    REGEX = enum.auto()
    GRAMMAR = enum.auto()
    CHOICE = enum.auto()


# Pair of (request kind, specification string).
StructuredOutputKey = tuple[StructuredOutputOptions, str]
class StructuredOutputGrammar(ABC):
    """Request-level backend for structured output requests.

    Concrete subclasses must implement every method below; the base class
    cannot be instantiated directly.
    """

    @abstractmethod
    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Report whether ``tokens`` are accepted for the given request.

        Args:
            request_id (str): Unique identifier of the request.
            tokens (list[int]): Token IDs to evaluate.

        Returns:
            bool: True when the tokens are accepted, False otherwise.
        """

    @abstractmethod
    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        """Fill the bitmask entry that belongs to one batch index.

        Args:
            bitmask (torch.Tensor): Bitmask to write into.
            batch_index (int): Index in the bitmask to fill.
        """

    @abstractmethod
    def is_terminated(self) -> bool:
        """Tell whether the structured output process has terminated.

        Returns:
            bool: True once terminated, False otherwise.
        """

    @abstractmethod
    def reset(self):
        """Reset the state of the structured output grammar."""
class StructuredOutputBackend(ABC):
    """Engine-level backend for structured output requests.

    Concrete subclasses must implement every method below; the base class
    cannot be instantiated directly.
    """

    @abstractmethod
    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Compile a grammar specification into a structured output grammar.

        Args:
            request_type (StructuredOutputOptions): Kind of structured
                output request being compiled.
            grammar_spec (str): The grammar specification to compile.

        Returns:
            StructuredOutputGrammar: The compiled structured output grammar.
        """

    @abstractmethod
    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate a token bitmask sized for ``max_num_seqs`` sequences.

        Args:
            max_num_seqs (int): Maximum number of sequences the bitmask
                has to cover.
        """
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar,
StructuredOutputOptions)
if TYPE_CHECKING:
import xgrammar as xgr
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
class XgrammarBackend(StructuredOutputBackend):
    """Structured output backend implemented on top of xgrammar.

    Construction resolves the tokenizer from the vLLM configuration, wraps
    it in an ``xgr.TokenizerInfo`` and builds the ``xgr.GrammarCompiler``
    used by every later ``compile_grammar`` call.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Set up the tokenizer information and the grammar compiler.

        Args:
            vllm_config: Engine configuration; its model, scheduler,
                parallel and LoRA sub-configs are used to build the
                tokenizer group.

        Raises:
            ValueError: If the tokenizer is a ``MistralTokenizer`` whose
                vocabulary cannot be read.
        """
        self.vllm_config = vllm_config
        group = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
        group.ping()
        tokenizer = group.get_lora_tokenizer(None)
        self.vocab_size = vllm_config.model_config.get_vocab_size()

        if not isinstance(tokenizer, MistralTokenizer):
            info = xgr.TokenizerInfo.from_huggingface(
                tokenizer,
                vocab_size=self.vocab_size,
            )
        else:
            info = self._mistral_tokenizer_info(tokenizer)
        self.compiler = xgr.GrammarCompiler(info, max_threads=8)

    def _mistral_tokenizer_info(self, tokenizer):
        """Build the ``xgr.TokenizerInfo`` by hand for a Mistral tokenizer."""
        # NOTE: ideally, xgrammar should handle this accordingly.
        # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
        try:
            # Vocabulary tokens ordered by their token id.
            by_token_id = sorted(
                tokenizer.get_vocab().items(),
                key=lambda entry: entry[1],
            )
            encoded_vocab = [token for token, _ in by_token_id]
            eos_id = getattr(tokenizer, "eos_token_id", None)
            stop_token_ids = None if eos_id is None else [eos_id]
        except AttributeError as e:
            raise ValueError(
                f"Cannot get the vocabulary of the tokenizer "
                f"{type(tokenizer)}. The tokenizer should have a "
                "get_vocab method.") from e
        return xgr.TokenizerInfo(  # type: ignore
            encoded_vocab=encoded_vocab,
            # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
            vocab_type=xgr.VocabType.BYTE_FALLBACK,
            vocab_size=self.vocab_size,
            stop_token_ids=stop_token_ids,
            add_prefix_space=True,
        )

    def _compile_context(self, request_type: StructuredOutputOptions,
                         grammar_spec: str):
        """Run the xgrammar compiler call that matches ``request_type``."""
        compiler = self.compiler
        if request_type == StructuredOutputOptions.JSON:
            return compiler.compile_json_schema(grammar_spec,
                                                any_whitespace=False)
        if request_type == StructuredOutputOptions.JSON_OBJECT:
            return compiler.compile_builtin_json_grammar()
        if request_type == StructuredOutputOptions.GRAMMAR:
            return compiler.compile_grammar(grammar_spec)
        if request_type == StructuredOutputOptions.REGEX:
            return compiler.compile_regex(grammar_spec)
        # Any other request type has no compiler call in this backend.
        logger.error(
            "Validation should have already occurred. Please file an issue."
        )
        raise ValueError(
            f"grammar is not of valid supported types. ({request_type!s})")

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Compile ``grammar_spec`` and wrap it in an ``XgrammarGrammar``.

        Args:
            request_type: Kind of structured output request; JSON,
                JSON_OBJECT, GRAMMAR and REGEX are handled.
            grammar_spec: Specification passed to the compiler (unused for
                JSON_OBJECT).

        Returns:
            An ``XgrammarGrammar`` holding a new matcher for the compiled
            grammar.

        Raises:
            ValueError: If ``request_type`` is not one of the handled kinds.
        """
        compiled = self._compile_context(request_type, grammar_spec)
        return XgrammarGrammar(
            matcher=xgr.GrammarMatcher(compiled),
            vocab_size=self.vocab_size,
            ctx=compiled,
        )

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate an xgrammar token bitmask for ``max_num_seqs`` sequences.

        The bitmask is sized using this backend's ``vocab_size``.
        """
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)
@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
    """Request-level grammar state backed by an xgrammar matcher.

    The xgrammar types in the field annotations are written as strings on
    purpose. This module does not enable postponed annotation evaluation,
    so an unquoted ``xgr.GrammarMatcher`` would be evaluated while the class
    body runs, and that attribute access on the lazy ``xgr`` loader would
    import xgrammar as soon as this module is imported.
    """
    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # TODO: support max_rollback_tokens
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    # Vocabulary size supplied by the backend that created this grammar.
    vocab_size: int
    # Matcher that is advanced token by token and fills the bitmask.
    matcher: "xgr.GrammarMatcher" = field(hash=False)
    # Compiled grammar the matcher was created from.
    ctx: "xgr.CompiledGrammar" = field(hash=False)
    # Number of tokens accepted so far; cleared by reset().
    num_processed_tokens: int = field(default_factory=lambda: 0,
                                      repr=False,
                                      hash=False,
                                      init=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Tokens are fed to the matcher one at a time; processing stops at the
        first token the matcher rejects.

        Args:
            request_id: Identifier used only in the error log message.
            tokens: Token IDs to feed to the matcher, in order.

        Returns:
            True if the FSM was advanced successfully for every token,
            False if the FSM failed to advance.
        """
        for token in tokens:
            if not self.matcher.accept_token(token):
                logger.error(
                    "Failed to advance FSM for request %s "
                    "for tokens %s. Please file an issue.", request_id, token)
                return False
            self.num_processed_tokens += 1
        return True

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Have the matcher fill row ``idx`` of ``bitmask`` in place."""
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        """Return whether the underlying matcher reports termination."""
        return self.matcher.is_terminated()

    def reset(self):
        """Clear the processed-token counter and reset the matcher."""
        self.num_processed_tokens = 0
        self.matcher.reset()
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
from vllm.utils import LazyLoader
if TYPE_CHECKING:
import xgrammar as xgr
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
@enum.unique
class StructuredOutputOptions(enum.Enum):
    """Kinds of structured output request.

    Member values are assigned automatically in declaration order, so the
    order of the members below must not change.
    """

    JSON = enum.auto()
    JSON_OBJECT = enum.auto()
    REGEX = enum.auto()
    GRAMMAR = enum.auto()
    CHOICE = enum.auto()


# Pair of (request kind, specification string).
StructuredOutputKey = tuple[StructuredOutputOptions, str]
@dataclass
class Grammar:
    """Per-request grammar state wrapping an xgrammar matcher."""
    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # TODO: support max_rollback_tokens
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    # Matcher that is advanced token by token and fills the bitmask.
    matcher: xgr.GrammarMatcher = field(hash=False)
    # Compiled grammar; reused to build a fresh matcher when copying.
    ctx: xgr.CompiledGrammar = field(hash=False)
    # Number of tokens accepted so far; starts at 0 and is cleared by reset().
    num_processed_tokens: int = field(default_factory=int,
                                      repr=False,
                                      hash=False,
                                      init=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Feed ``tokens`` to the matcher one at a time, advancing the FSM.

        Returns True when every token was accepted, and False as soon as the
        matcher rejects one (the rejection is logged with ``request_id``).
        """
        for token_id in tokens:
            if self.matcher.accept_token(token_id):
                self.num_processed_tokens += 1
                continue
            logger.error(
                "Failed to advance FSM for request %s "
                "for tokens %s. Please file an issue.", request_id, token_id)
            return False
        return True

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool:
        """Fill row ``idx`` of ``bitmask`` and return the matcher's result."""
        filled = self.matcher.fill_next_token_bitmask(bitmask, idx)
        return filled

    def reset(self):
        """Clear the processed-token counter and reset the matcher."""
        self.num_processed_tokens = 0
        self.matcher.reset()

    def __copy__(self):
        """Return a new ``Grammar`` sharing ``ctx`` but with a new matcher."""
        fresh_matcher = xgr.GrammarMatcher(self.ctx)
        return Grammar(
            vocab_size=self.vocab_size,
            matcher=fresh_matcher,
            ctx=self.ctx,
        )
...@@ -9,7 +9,8 @@ from concurrent.futures._base import TimeoutError ...@@ -9,7 +9,8 @@ from concurrent.futures._base import TimeoutError
from typing import Optional, Union, cast from typing import Optional, Union, cast
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey, from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar,
StructuredOutputKey,
StructuredOutputOptions) StructuredOutputOptions)
...@@ -17,7 +18,8 @@ from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey, ...@@ -17,7 +18,8 @@ from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
class StructuredOutputRequest: class StructuredOutputRequest:
sampling_params: SamplingParams sampling_params: SamplingParams
_grammar: Optional[Union[Future[Grammar], Grammar]] = None _grammar: Optional[Union[Future[StructuredOutputGrammar],
StructuredOutputGrammar]] = None
def _check_grammar_completion(self) -> bool: def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports # NOTE: We have to lazy import to gate circular imports
...@@ -37,12 +39,16 @@ class StructuredOutputRequest: ...@@ -37,12 +39,16 @@ class StructuredOutputRequest:
return self._check_grammar_completion() return self._check_grammar_completion()
@property @property
def grammar(self) -> Optional[Grammar]: def grammar(self) -> Optional[StructuredOutputGrammar]:
completed = self._check_grammar_completion() completed = self._check_grammar_completion()
return cast(Optional[Grammar], self._grammar) if completed else None return cast(Optional[StructuredOutputGrammar],
self._grammar) if completed else None
@grammar.setter @grammar.setter
def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None: def grammar(
self, grammar: Union[StructuredOutputGrammar,
Future[StructuredOutputGrammar]]
) -> None:
self._grammar = grammar self._grammar = grammar
@functools.cached_property @functools.cached_property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment