"...hello_world/multinode_example/components/frontend.py" did not exist on "d675d2218e5b271e8434cd03bb3384a2641f12b1"
Unverified Commit 80e9afb5 authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[V1][Core] Support for Structured Outputs (#12388)


Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
Co-authored-by: default avatarRussell Bryant <rbryant@redhat.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 1e3598ed
......@@ -3,13 +3,15 @@
import enum
from typing import TYPE_CHECKING, Optional, Union
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
......@@ -27,15 +29,19 @@ class Request:
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
) -> None:
self.request_id = request_id
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.structured_output_request = structured_output_request
self.status = RequestStatus.WAITING
self.status = (RequestStatus.WAITING_FOR_FSM
if sampling_params.guided_decoding is not None else
RequestStatus.WAITING)
self.events: list[EngineCoreEvent] = []
self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None
......@@ -78,6 +84,8 @@ class Request:
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params),
)
def queued(self, timestamp: Optional[float] = None) -> None:
......@@ -134,18 +142,23 @@ class Request:
num_tokens = self.mm_positions[input_id]["length"]
return num_tokens
@property
def use_structured_output(self) -> bool:
return self.sampling_params.guided_decoding is not None
class RequestStatus(enum.IntEnum):
"""Status of a request."""
WAITING = 0
RUNNING = 1
PREEMPTED = 2
# Note: anything after PREEMPTED (2) will be considered
WAITING = enum.auto()
WAITING_FOR_FSM = enum.auto()
RUNNING = enum.auto()
PREEMPTED = enum.auto()
# Note: anything after PREEMPTED will be considered
# as a finished status.
FINISHED_STOPPED = 3
FINISHED_LENGTH_CAPPED = 4
FINISHED_ABORTED = 5
FINISHED_IGNORED = 6
FINISHED_STOPPED = enum.auto()
FINISHED_LENGTH_CAPPED = enum.auto()
FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
@staticmethod
def is_finished(status: "RequestStatus") -> bool:
......
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import copy
import multiprocessing
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
StructuredOutputOptions)
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import xgrammar as xgr
from vllm.v1.request import Request
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
class StructuredOutputManager:
def __init__(self, vllm_config: VllmConfig, max_cache_size: int = 500):
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()
self.vocab_size = vllm_config.model_config.get_vocab_size()
self.vllm_config = vllm_config
tokenizer = tokenizer_group.get_lora_tokenizer(None)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer, vocab_size=self.vocab_size)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
self.max_cache_size = max_cache_size
self.request_key_to_grammar: OrderedDict[StructuredOutputKey,
Grammar] = OrderedDict()
# The default max_workers if not specified is the number of CPUs * 5,
# which is way too high since these tasks are CPU-bound, not I/O bound.
# We also know we would never dominate CPU usage with just grammar
# compilation, so we set it to half the number of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self._grammar_bitmask = xgr.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]:
# We need to pop and re-insert the grammar here for LRU cache
# of request_key_to_grammar
if key in self.request_key_to_grammar:
# Move accessed item to the end (most recently used)
value = self.request_key_to_grammar.pop(key)
if value is not None:
self.request_key_to_grammar[key] = value
return value
return None
def populate_cache(self, request: Request) -> None:
if request.structured_output_request is None:
return
grammar = self.request_key_to_grammar.get(
request.structured_output_request.structured_output_key)
if grammar:
request.structured_output_request.grammar = copy.copy(grammar)
return
request.structured_output_request.grammar = self.cache(request)
def cache(self, request: Request):
return self.executor.submit(self._executor_loop, request)
def _executor_loop(self, request: Request) -> Grammar:
# NOTE: The structured_output_request should never be
# None in this case, but mypy can't infer this
# correctly, so we need to ignore the error here.
key = request.structured_output_request.structured_output_key # type: ignore[union-attr]
grammar = self.request_key_to_grammar.get(key)
if grammar is not None:
return copy.copy(grammar)
grammar = self.initialize_grammar(key)
# If cache is full, remove the least recently used item
if len(self.request_key_to_grammar) >= self.max_cache_size:
self.request_key_to_grammar.popitem(last=False)
self.request_key_to_grammar[key] = grammar
return copy.copy(grammar)
def initialize_grammar(self, key: StructuredOutputKey) -> Grammar:
# Note that the request was validated in the engine core client,
# so at this point we know it is a supported type of request.
#
# TODO: we still need to handle xgrammar compilation failures
request_type, grammar_spec = key
if request_type == StructuredOutputOptions.JSON:
# TODO -- allow any_whitespace to be configurable
# pending merge of https://github.com/vllm-project/vllm/pull/12744
ctx = self.compiler.compile_json_schema(grammar_spec,
any_whitespace=False)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
ctx = self.compiler.compile_builtin_json_grammar()
elif request_type == StructuredOutputOptions.GRAMMAR:
ctx = self.compiler.compile_grammar(grammar_spec)
else:
logger.error("Validation should have already occurred. "
"Please file an issue.")
raise ValueError(
f"grammar is not of valid supported types. ({request_type!s})")
return Grammar(
matcher=xgr.GrammarMatcher(ctx),
vocab_size=self.vocab_size,
ctx=ctx,
)
def grammar_bitmask(
self,
requests: dict[str, Request],
structured_output_request_ids: dict[str, int],
batch_len: int,
) -> Optional[npt.NDArray[np.int32]]:
# Prepare the structured output bitmask for this batch.
if not structured_output_request_ids:
return None
# Fill the bitmask using the index of each request equal to its
# position in the batch. Resize the bitmask down to the size of
# the batch.
bitmask_tensor = self._grammar_bitmask
for req_id, batch_index in structured_output_request_ids.items():
request = requests[req_id].structured_output_request
assert request is not None and request.grammar is not None
if not request.grammar.matcher.is_terminated():
request.grammar.fill_bitmask(bitmask_tensor, batch_index)
if batch_len < self._grammar_bitmask.shape[0]:
bitmask_tensor = self._grammar_bitmask[:batch_len]
# After finishing with the xgrammar operations, we convert to
# np.ndarray, because that is much more efficient for serialization
# and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy()
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
from vllm.utils import LazyLoader
if TYPE_CHECKING:
import xgrammar as xgr
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
class StructuredOutputOptions(enum.Enum):
JSON = enum.auto()
JSON_OBJECT = enum.auto()
REGEX = enum.auto()
GRAMMAR = enum.auto()
CHOICE = enum.auto()
StructuredOutputKey = tuple[StructuredOutputOptions, str]
@dataclass
class Grammar:
# NOTE: This would be a generic-enough class for
# supporting different backends, in the future.
# For now, just xgrammar.
#
# TODO: support max_rollback_tokens
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
# for jump-forward decoding
vocab_size: int
matcher: xgr.GrammarMatcher = field(hash=False)
ctx: xgr.CompiledGrammar = field(hash=False)
num_processed_tokens: int = field(default_factory=lambda: 0,
repr=False,
hash=False,
init=False)
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM.
Returns True if the FSM was advanced successfully.
Returns False if the FSM failed to advance.
"""
for token in tokens:
if not self.matcher.accept_token(token):
logger.error(
"Failed to advance FSM for request %s "
"for tokens %s. Please file an issue.", request_id, token)
return False
self.num_processed_tokens += 1
return True
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool:
return self.matcher.fill_next_token_bitmask(bitmask, idx)
def reset(self):
self.num_processed_tokens = 0
self.matcher.reset()
def __copy__(self):
return Grammar(
matcher=xgr.GrammarMatcher(self.ctx),
vocab_size=self.vocab_size,
ctx=self.ctx,
)
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import dataclasses
import functools
import json
from concurrent.futures import Future
from concurrent.futures._base import TimeoutError
from typing import Optional, Union, cast
from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
StructuredOutputOptions)
@dataclasses.dataclass
class StructuredOutputRequest:
sampling_params: SamplingParams
_grammar: Optional[Union[Future[Grammar], Grammar]] = None
def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports
from vllm.v1.request import RequestStatus
if isinstance(self._grammar, Future):
try:
# We will check whether the future is ready within 100 us
self._grammar = self._grammar.result(timeout=0.0001)
self.status = RequestStatus.WAITING
except TimeoutError:
return False
return True
@property
def is_grammar_ready(self) -> bool:
return self._check_grammar_completion()
@property
def grammar(self) -> Optional[Grammar]:
completed = self._check_grammar_completion()
return cast(Optional[Grammar], self._grammar) if completed else None
@grammar.setter
def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None:
self._grammar = grammar
@functools.cached_property
def structured_output_key(self) -> StructuredOutputKey:
params = self.sampling_params.guided_decoding
assert params is not None, "params can't be None."
if params.json is not None:
if not isinstance(params.json, str):
json_str = json.dumps(params.json)
else:
json_str = params.json
return (StructuredOutputOptions.JSON, json_str)
elif params.json_object:
return (StructuredOutputOptions.JSON_OBJECT, "")
elif params.regex is not None:
return (StructuredOutputOptions.REGEX, params.regex)
elif params.choice is not None:
if not isinstance(params.choice, str):
json_str = json.dumps(params.choice)
else:
json_str = params.choice
return (StructuredOutputOptions.CHOICE, json_str)
elif params.grammar is not None:
return (StructuredOutputOptions.GRAMMAR, params.grammar)
else:
raise ValueError("No valid structured output parameter found")
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import json
import re
from typing import TYPE_CHECKING, Any
from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
if TYPE_CHECKING:
import xgrammar as xgr
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
"""Check if JSON schema contains features unsupported by xgrammar."""
def check_object(obj: dict[str, Any]) -> bool:
if not isinstance(obj, dict):
return False
# Check for pattern restrictions
if "pattern" in obj:
return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj
for key in ("minimum", "maximum", "exclusiveMinimum",
"exclusiveMaximum", "multipleOf")):
return True
# Check for array unsupported keywords
if obj.get("type") == "array" and any(
key in obj
for key in ("uniqueItems", "contains", "minContains",
"maxContains", "minItems", "maxItems")):
return True
# Unsupported keywords for strings
if obj.get("type") == "string" and any(
key in obj for key in ("minLength", "maxLength", "format")):
return True
# Unsupported keywords for objects
if obj.get("type") == "object" and any(
key in obj for key in ("minProperties", "maxProperties",
"propertyNames", "patternProperties")):
return True
# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True
return False
return check_object(schema)
def grammar_is_likely_lark(grammar_str: str) -> bool:
"""
Check if grammar appears to use Lark syntax.
Args:
grammar_str: Input grammar string
Returns:
bool: True if grammar appears to be in Lark format, False otherwise
Examples:
>>> grammar_is_likely_lark("rule: 'abc'")
True
>>> grammar_is_likely_lark("rule ::= 'abc'")
False
"""
if not grammar_str or not isinstance(grammar_str, str):
return False
for line in grammar_str.split('\n'):
# Remove both comment styles
line = re.sub(r'(#|//).*$', '', line).strip()
if not line:
continue
# Look for EBNF rule definition
if '::=' in line:
return False
return True
def convert_lark_to_ebnf(grammar_str: str) -> str:
"""
Convert a Lark grammar string to EBNF format.
EBNF reference:
https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Lark grammar reference:
https://lark-parser.readthedocs.io/en/latest/grammar.html
Args:
grammar_str: Input grammar in Lark format
Returns:
str: Converted grammar in EBNF format
Examples:
>>> print(convert_lark_to_ebnf("rule: 'hello'"))
root ::= rule
rule ::= "hello"
"""
if not isinstance(grammar_str, str):
raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
if not grammar_str.strip():
raise ValueError("Grammar string cannot be empty")
defined_rules = set()
referenced_rules = set()
output_lines = []
def clean_line(line: str) -> str:
"""Remove comments and whitespace from line."""
return re.sub(r'(#|//).*$', '', line).strip()
def check_quotes(text: str, rule_name: str, line_num: int) -> None:
"""Validate quote matching in text."""
if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
raise ValueError(
f"Mismatched quotes in {rule_name} on line {line_num}")
def extract_references(text: str) -> set:
"""Extract rule references from text."""
# Remove quoted strings and special characters
text = re.sub(r'"[^"]*"', '', text)
text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))
# First pass: Find root rule and validate rule definitions
lines = [clean_line(line) for line in grammar_str.split('\n')]
first_rule = None
for line_num, line in enumerate(lines, 1):
if not line or line.startswith('|'):
continue
if ':' in line:
try:
name = line.split(':', 1)[0].strip().strip('?')
defined_rules.add(name)
if first_rule is None:
first_rule = name
if name == 'start':
first_rule = 'start'
except IndexError as e:
raise ValueError(f"Invalid rule format on line {line_num}. "
"Expected 'rule_name: definition'") from e
if not defined_rules:
raise ValueError("No valid rules found in grammar")
# Add root rule
output_lines.append(f"root ::= {first_rule}")
# Second pass: Process rule definitions and alternatives
current_rule = None
current_definition = []
for line_num, line in enumerate(lines, 1):
if not line:
continue
try:
if ':' in line and not line.startswith('|'):
# Save previous rule if exists
if current_rule:
output_lines.append(
f"{current_rule} ::= {' | '.join(current_definition)}")
# Process new rule
name, definition = line.split(':', 1)
current_rule = name.strip().strip('?')
check_quotes(definition, f"rule '{current_rule}'", line_num)
definition = re.sub(r"'([^']*)'", r'"\1"', definition)
referenced_rules.update(extract_references(definition))
current_definition = [definition.strip()]
elif line.startswith('|'):
if not current_rule:
raise ValueError(f"Alternative '|' on line {line_num} "
"without a preceding rule definition")
alt_def = line[1:].strip()
check_quotes(alt_def, f"alternative for rule '{current_rule}'",
line_num)
alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
referenced_rules.update(extract_references(alt_def))
current_definition.append(alt_def)
except ValueError as e:
raise ValueError(f"Error on line {line_num}: {str(e)}") from e
# Add final rule if exists
if current_rule:
output_lines.append(
f"{current_rule} ::= {' | '.join(current_definition)}")
# Validate all rules are defined
undefined_rules = referenced_rules - defined_rules - {'root'}
if undefined_rules:
raise ValueError("Referenced rules are not defined: "
f"{', '.join(sorted(undefined_rules))}")
return '\n'.join(output_lines)
def choice_as_grammar(choice: list[str]) -> str:
def escape_ebnf_string(s: str) -> str:
"""Escape special characters in a EBNF string."""
# Escape double quotes and backslashes
return re.sub(r'(["\\])', r'\\\1', s)
escaped_choices = (escape_ebnf_string(c) for c in choice)
grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
return grammar
def validate_structured_output_request(
sampling_params: SamplingParams) -> None:
"""Validate that the request is supported by structured output.
Raises ValueError if the request is not supported.
"""
if sampling_params.guided_decoding is None:
return
gd_params = sampling_params.guided_decoding
if gd_params.regex:
raise ValueError("Regex structured output is not supported.")
if gd_params.choice:
choice_grammar = choice_as_grammar(gd_params.choice)
try:
xgr.Grammar.from_ebnf(choice_grammar)
except Exception as err:
raise ValueError("Failed to transform choices into a grammar: "
"{err}") from err
gd_params.choice = None
gd_params.grammar = choice_grammar
return
if gd_params.json:
if isinstance(gd_params.json, str):
try:
schema = json.loads(gd_params.json)
except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e
else:
schema = gd_params.json
if has_xgrammar_unsupported_json_features(schema):
raise ValueError("The provided JSON schema contains features not "
"supported by xgrammar.")
return
if gd_params.grammar:
if grammar_is_likely_lark(gd_params.grammar):
# xgrammar supports EBNF grammars only
try:
gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar)
except ValueError as e:
raise ValueError(
"Failed to convert the grammar from Lark to EBNF. ") from e
# Test parsing EBNF grammar, possibly already converted from Lark
try:
# parse the grammar, but we aren't compiling it.
xgr.Grammar.from_ebnf(gd_params.grammar)
except Exception as e:
raise ValueError("Invalid grammar specification.") from e
......@@ -25,7 +25,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
LayerBlockType, LazyLoader, cdiv,
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
......@@ -40,7 +41,11 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
if TYPE_CHECKING:
import xgrammar as xgr
from vllm.v1.core.scheduler_output import SchedulerOutput
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
......@@ -860,6 +865,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def get_model(self) -> nn.Module:
return self.model
def apply_grammar_bitmask(
self,
scheduler_output: "SchedulerOutput",
logits: torch.Tensor,
):
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None:
return
# We receive the structured output bitmask from the scheduler, but the
# indices of the requests in the batch may not match the indices of
# the bitmask since the scheduler doesn't know how the gpu runner is
# ordering the requests in the batch. We need to sort the bitmask to
# match the order of the requests used here.
struct_out_req_batch_indices: dict[str, int] = {}
indices_match = True
for req_id in self.input_batch.req_ids:
mask_index = scheduler_output.structured_output_request_ids.get(
req_id)
if mask_index is None:
# not a structured output request
continue
batch_index = self.input_batch.req_id_to_index[req_id]
if batch_index != mask_index:
indices_match = False
struct_out_req_batch_indices[req_id] = batch_index
if not indices_match:
# Sort the bitmask to match the order of the requests
sorted_bitmask = np.zeros_like(grammar_bitmask)
for req_id, batch_index in struct_out_req_batch_indices.items():
orig_index = scheduler_output.structured_output_request_ids[
req_id]
sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
grammar_bitmask = sorted_bitmask
grammar_bitmask = torch.from_numpy(grammar_bitmask)
# TODO: compatibility with spec decode
xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask.to(self.device, non_blocking=True),
indices=list(struct_out_req_batch_indices.values()),
)
@torch.inference_mode()
def execute_model(
self,
......@@ -945,6 +997,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if not self.use_spec_decode:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment