# SPDX-License-Identifier: Apache-2.0

import copy
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)
from vllm.v1.structured_output.request import get_structured_output_key

if TYPE_CHECKING:
    import llguidance
    import llguidance.hf as llguidance_hf
    import llguidance.torch as llguidance_torch
else:
    llguidance = LazyLoader("llguidance", globals(), "llguidance")
    llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
    llguidance_torch = LazyLoader("llguidance.torch", globals(),
                                  "llguidance.torch")

logger = init_logger(__name__)


def _walk_json_for_additional_properties(data: object):
    """Recursively set ``additionalProperties: False`` in a JSON schema.

    The structure is modified in place. Every dict that declares
    ``properties`` or ``patternProperties`` but has no explicit
    ``additionalProperties`` key gets ``additionalProperties = False``.
    Lists are walked element by element; any other value is ignored.
    """
    if isinstance(data, dict):
        # Recurse first; the key is only inserted once iteration over this
        # dict has finished, so the dict is never resized while iterating.
        for value in data.values():
            _walk_json_for_additional_properties(value)
        if 'additionalProperties' not in data and \
            ('properties' in data or 'patternProperties' in data):
            data['additionalProperties'] = False
    elif isinstance(data, list):
        for item in data:
            _walk_json_for_additional_properties(item)


def process_for_additional_properties(
        guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
    """Return a schema in which additional properties are disallowed.

    Args:
        guide_json: The JSON schema, either as a JSON string or as an
            already-parsed dict.

    Returns:
        A new object (parsed from the string, or a deep copy of the dict)
        processed by ``_walk_json_for_additional_properties``. The argument
        itself is never modified.
    """
    if isinstance(guide_json, str):
        guide_json_obj = json.loads(guide_json)
    else:
        # copy for modifications
        guide_json_obj = copy.deepcopy(guide_json)
    _walk_json_for_additional_properties(guide_json_obj)
    return guide_json_obj


class GuidanceBackend(StructuredOutputBackend):
    """Structured-output backend that builds grammars with llguidance."""

    def __init__(self, vllm_config: VllmConfig):
        """Set up the tokenizer and parse the backend options.

        Args:
            vllm_config: Engine configuration; the model, scheduler,
                parallel, LoRA and decoding sub-configs are read from it.

        Raises:
            ValueError: If the configured guided decoding backend string
                carries an option other than ``disable-any-whitespace`` or
                ``no-additional-properties``.
        """
        self.vllm_config = vllm_config
        tokenizer_group = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
        tokenizer_group.ping()
        self.vocab_size = vllm_config.model_config.get_vocab_size()

        self.disable_any_whitespace = False
        self.no_additional_properties = False
        backend_options = GuidedDecodingParams(
            backend=vllm_config.decoding_config.guided_decoding_backend
        ).backend_options()
        for option in backend_options:
            if option == "disable-any-whitespace":
                self.disable_any_whitespace = True
            elif option == "no-additional-properties":
                self.no_additional_properties = True
            else:
                raise ValueError(
                    f"Unsupported option for the guidance backend: {option}")

        tokenizer = tokenizer_group.get_lora_tokenizer(None)
        self.ll_tokenizer = llguidance_hf.from_tokenizer(
            tokenizer, self.vocab_size)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Serialize ``grammar_spec`` and wrap a matcher for it.

        Args:
            request_type: Kind of structured output being requested.
            grammar_spec: The specification, interpreted according to
                ``request_type``.

        Returns:
            A ``GuidanceGrammar`` holding a new ``LLMatcher``. Any error the
            matcher already reports is logged before returning.
        """
        # Build the matcher from a local, not from the attribute: if another
        # caller overwrites ``self.serialized_grammar`` between the write and
        # the read, this request would otherwise get the wrong grammar.
        serialized_grammar = serialize_guidance_grammar(
            request_type, grammar_spec, self.disable_any_whitespace,
            self.no_additional_properties)
        # Still published on the instance for backward compatibility.
        self.serialized_grammar = serialized_grammar

        ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )

        r = GuidanceGrammar(
            ll_matcher=ll_matcher,
            ll_tokenizer=self.ll_tokenizer,
            vocab_size=self.vocab_size,
        )

        r.check_error()
        return r

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate a token bitmask through ``llguidance_torch``.

        Delegates to ``llguidance_torch.allocate_token_bitmask``, passing
        ``max_num_seqs`` and the llguidance tokenizer's vocab size.
        """
        return llguidance_torch.allocate_token_bitmask(
            max_num_seqs, self.ll_tokenizer.vocab_size)


@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
    """Per-request grammar state backed by an llguidance ``LLMatcher``."""

    ll_matcher: llguidance.LLMatcher
    ll_tokenizer: llguidance.LLTokenizer
    vocab_size: int
    # Set once a matcher error has been logged, so it is reported only once.
    printed_error: bool = False
    # Set once the tokenizer's EOS token has been seen in accepted tokens.
    terminated: bool = False

    def check_error(self):
        """Log the matcher's error, if any, the first time it is seen."""
        if not self.printed_error:
            err = self.ll_matcher.get_error()
            if err:
                self.printed_error = True
                logger.warning("LLMatcher error: %s", err)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the parser.

        Returns True if the parser was advanced successfully.
        Returns False if the parser failed to advance.
        """

        if self.ll_tokenizer.eos_token in tokens:
            self.terminated = True

        # A stopped matcher is not fed any further tokens; the call is
        # reported as successful.
        if self.ll_matcher.is_stopped():
            return True

        # TODO - Add jump decoding support in the future:
        # self.ll_matcher.compute_ff_bytes() - this should always work
        # self.ll_matcher.compute_ff_tokens() - this only works for
        #   "canonical" tokenizers
        # For conversion between the two, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md

        r = self.ll_matcher.consume_tokens(tokens)

        self.check_error()

        return r

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Fill entry ``idx`` of ``bitmask`` from the matcher's state."""
        # this will automatically return [EOS] mask if the matcher is stopped
        # or otherwise in an error state
        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
        self.check_error()

    def is_terminated(self) -> bool:
        """Return whether an EOS token has been accepted."""
        return self.terminated

    def reset(self):
        """Reset the matcher and clear the termination flag."""
        # This method may be not needed anymore? TODO
        self.ll_matcher.reset()
        # Without this, is_terminated() would keep reporting a stale True
        # for a matcher that has just been reset.
        self.terminated = False


def serialize_guidance_grammar(
    request_type: StructuredOutputOptions,
    grammar_spec: Union[str, dict[str, Any]],
    disable_any_whitespace: bool = False,
    no_additional_properties: bool = False,
) -> str:
    """Convert a structured-output request into an llguidance grammar.

    Args:
        request_type: Kind of structured output being requested.
        grammar_spec: The specification; a JSON schema for ``JSON``, ignored
            for ``JSON_OBJECT``, otherwise passed to
            ``llguidance.grammar_from``.
        disable_any_whitespace: When True, ``whitespace_flexible`` is set to
            False for the JSON-based request types.
        no_additional_properties: When True, a ``JSON`` schema is first run
            through ``process_for_additional_properties``.

    Returns:
        The grammar produced by llguidance.

    Raises:
        ValueError: If ``request_type`` is not one of the supported types.
    """
    if request_type == StructuredOutputOptions.JSON:
        if no_additional_properties:
            grammar_spec = process_for_additional_properties(grammar_spec)
        return llguidance.LLMatcher.grammar_from_json_schema(
            grammar_spec,
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })
    elif request_type == StructuredOutputOptions.JSON_OBJECT:
        return llguidance.LLMatcher.grammar_from_json_schema(
            '{"type": "object"}',
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })
    else:
        if request_type == StructuredOutputOptions.REGEX:
            tp = "regex"
        elif request_type == StructuredOutputOptions.GRAMMAR:
            tp = "grammar"
        elif request_type == StructuredOutputOptions.CHOICE:
            tp = "choice"
        else:
            logger.error("Validation should have already occurred. "
                         "Please file an issue.")
            raise ValueError("grammar is not of valid supported types. "
                             f"({request_type!s})")
        return llguidance.grammar_from(tp, grammar_spec)


def validate_guidance_grammar(
        sampling_params: SamplingParams,
        tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
    """Validate the structured-output request in ``sampling_params``.

    The request is serialized with the default whitespace and
    additional-properties settings and checked with
    ``llguidance.LLMatcher.validate_grammar``.

    Args:
        sampling_params: Sampling parameters carrying the request.
        tokenizer: Optional llguidance tokenizer forwarded to the validator.

    Raises:
        ValueError: If llguidance reports an error for the grammar.
    """
    tp, grm = get_structured_output_key(sampling_params)
    guidance_grm = serialize_guidance_grammar(tp, grm)
    err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
    if err:
        raise ValueError(f"Grammar error: {err}")