Unverified commit 04421dff, authored by Russell Bryant, committed by GitHub
Browse files

[V1] Prevent xgrammar from breaking TPU support (#14575)


Signed-off-by: Russell Bryant <rbryant@redhat.com>
parent 432d6dad
...@@ -4,6 +4,7 @@ import time ...@@ -4,6 +4,7 @@ import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Optional, Union from typing import Optional, Union
import vllm.platforms
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter) PromptType, SingletonInputsAdapter)
...@@ -133,6 +134,9 @@ class Processor: ...@@ -133,6 +134,9 @@ class Processor:
if self.vllm_config.speculative_config: if self.vllm_config.speculative_config:
raise ValueError("Structured output is not supported with " raise ValueError("Structured output is not supported with "
"speculative decoding.") "speculative decoding.")
if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.")
validate_structured_output_request(params) validate_structured_output_request(params)
def process_inputs( def process_inputs(
......
...@@ -17,6 +17,7 @@ from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey, ...@@ -17,6 +17,7 @@ from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
if TYPE_CHECKING: if TYPE_CHECKING:
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import torch
import xgrammar as xgr import xgrammar as xgr
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -53,8 +54,7 @@ class StructuredOutputManager: ...@@ -53,8 +54,7 @@ class StructuredOutputManager:
# compilation, so we set it to half the number of CPUs. # compilation, so we set it to half the number of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers) self.executor = ThreadPoolExecutor(max_workers=max_workers)
self._grammar_bitmask = xgr.allocate_token_bitmask( self._grammar_bitmask: Optional[torch.Tensor] = None
self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]: def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]:
# We need to pop and re-insert the grammar here for LRU cache # We need to pop and re-insert the grammar here for LRU cache
...@@ -134,6 +134,11 @@ class StructuredOutputManager: ...@@ -134,6 +134,11 @@ class StructuredOutputManager:
if not structured_output_request_ids: if not structured_output_request_ids:
return None return None
if self._grammar_bitmask is None:
self._grammar_bitmask = xgr.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs,
self.vocab_size)
# Fill the bitmask using the index of each request equal to its # Fill the bitmask using the index of each request equal to its
# position in the batch. Resize the bitmask down to the size of # position in the batch. Resize the bitmask down to the size of
# the batch. # the batch.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment