Unverified Commit 4a98edff authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[Structured Outputs][V1] Skip structured-output initialization for models that don't have a tokenizer (#20365)


Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Nick Hill <nhill@redhat.com>
parent a7bab0c9
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
...@@ -17,6 +17,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, ...@@ -17,6 +17,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
EOS_TOKEN_ID = 50256 EOS_TOKEN_ID = 50256
...@@ -33,6 +34,7 @@ def create_scheduler( ...@@ -33,6 +34,7 @@ def create_scheduler(
block_size: int = 16, block_size: int = 16,
max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None, num_speculative_tokens: Optional[int] = None,
skip_tokenizer_init: bool = False,
) -> Scheduler: ) -> Scheduler:
'''Create scheduler under test. '''Create scheduler under test.
...@@ -65,6 +67,7 @@ def create_scheduler( ...@@ -65,6 +67,7 @@ def create_scheduler(
trust_remote_code=True, trust_remote_code=True,
dtype="float16", dtype="float16",
seed=42, seed=42,
skip_tokenizer_init=skip_tokenizer_init,
) )
# Cache config, optionally force APC # Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else { kwargs_cache = ({} if enable_prefix_caching is None else {
...@@ -1857,3 +1860,39 @@ def test_priority_scheduling_heap_property(): ...@@ -1857,3 +1860,39 @@ def test_priority_scheduling_heap_property():
# Verify requests were scheduled in priority order (lowest value first) # Verify requests were scheduled in priority order (lowest value first)
expected_priorities = sorted(priorities) expected_priorities = sorted(priorities)
assert scheduled_priorities == expected_priorities assert scheduled_priorities == expected_priorities
def test_schedule_skip_tokenizer_init():
    """Plain requests are still scheduled when tokenizer init is skipped.

    Builds a scheduler with ``skip_tokenizer_init=True``, enqueues five
    requests, and checks that a single scheduling step admits all of them
    as new requests and produces no grammar bitmask.
    """
    scheduler = create_scheduler(skip_tokenizer_init=True)

    pending = create_requests(num_requests=5)
    for req in pending:
        scheduler.add_request(req)

    step = scheduler.schedule()

    # Every queued request should show up as newly scheduled.
    assert len(step.scheduled_new_reqs) == len(pending)
    # None of these requests asked for structured output.
    assert step.grammar_bitmask is None
def test_schedule_skip_tokenizer_init_structured_output_request():
    """A structured-output request is not scheduled without a tokenizer.

    With ``skip_tokenizer_init=True`` a request carrying guided-decoding
    params is added to the scheduler; after one scheduling step it must
    still sit in the waiting queue, with nothing scheduled or running.
    """
    scheduler = create_scheduler(skip_tokenizer_init=True)

    # Regex-constrained decoding makes this a structured-output request.
    params = SamplingParams(
        ignore_eos=False,
        max_tokens=16,
        guided_decoding=GuidedDecodingParams(regex="[0-9]+"),
    )
    structured_req = Request(
        request_id="0",
        prompt_token_ids=[0, 1],
        multi_modal_inputs=None,
        multi_modal_hashes=None,
        multi_modal_placeholders=None,
        sampling_params=params,
        pooling_params=None,
        eos_token_id=EOS_TOKEN_ID,
        structured_output_request=StructuredOutputRequest(params),
    )
    scheduler.add_request(structured_req)

    step = scheduler.schedule()

    # NOTE(review): presumably the request is held back because its grammar
    # cannot be prepared without a tokenizer - confirm against the scheduler.
    assert len(step.scheduled_new_reqs) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random import random
from typing import Optional from typing import TYPE_CHECKING, Optional
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
if TYPE_CHECKING:
from tests.conftest import VllmRunner
MODEL = "facebook/opt-125m" MODEL = "facebook/opt-125m"
DTYPE = "half" DTYPE = "half"
def _vllm_model(apc: bool, vllm_runner, monkeypatch): def _vllm_model(
apc: bool,
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
*,
skip_tokenizer_init: bool = False,
):
"""Set up VllmRunner instance.""" """Set up VllmRunner instance."""
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
return vllm_runner( return vllm_runner(
...@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch): ...@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
enforce_eager=True, enforce_eager=True,
enable_prefix_caching=apc, enable_prefix_caching=apc,
gpu_memory_utilization=0.5, gpu_memory_utilization=0.5,
skip_tokenizer_init=skip_tokenizer_init,
) )
...@@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch): ...@@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch):
yield vllm_model yield vllm_model
@pytest.fixture(
    # Function scope decouples tests & allows
    # env var adjustment via monkeypatch
    scope="function",
    # Prefix caching
    params=[False, True])
def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
    """VllmRunner test fixture with tokenizer initialization skipped.

    The fixture is parametrized over prefix caching (disabled, then
    enabled); ``request.param`` is forwarded as the ``apc`` argument of
    ``_vllm_model``. The runner is always created with
    ``skip_tokenizer_init=True`` and is used as a context manager around
    the ``yield`` that hands it to the test.
    """
    with _vllm_model(
            request.param,
            vllm_runner,
            monkeypatch,
            skip_tokenizer_init=True,
    ) as vllm_model:
        yield vllm_model
def _get_test_sampling_params( def _get_test_sampling_params(
prompt_list: list[str], prompt_list: list[str],
seed: Optional[int] = 42, seed: Optional[int] = 42,
structured_outputs: bool = False,
) -> tuple[list[SamplingParams], list[int]]: ) -> tuple[list[SamplingParams], list[int]]:
"""Generate random sampling params for a batch.""" """Generate random sampling params for a batch."""
...@@ -62,11 +92,31 @@ def _get_test_sampling_params( ...@@ -62,11 +92,31 @@ def _get_test_sampling_params(
n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))] n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
# High temperature to maximize the chance of unique completions # High temperature to maximize the chance of unique completions
return [ return [
SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed) SamplingParams(
for n in n_list temperature=0.95,
top_p=0.95,
n=n,
seed=seed,
guided_decoding=GuidedDecodingParams(
regex="[0-9]+") if structured_outputs else None,
) for n in n_list
], n_list ], n_list
def test_compatibility_with_skip_tokenizer_init(
    vllm_model_skip_tokenizer_init: VllmRunner,
    example_prompts: list[str],
):
    """Structured-output requests must be rejected without a tokenizer.

    Generates sampling params that carry guided-decoding settings and
    expects ``LLM.generate`` to raise ``ValueError`` when the engine was
    created with ``skip_tokenizer_init=True``.
    """
    guided_params, _ = _get_test_sampling_params(
        example_prompts,
        structured_outputs=True,
    )
    llm: LLM = vllm_model_skip_tokenizer_init.model

    # The call is expected to fail, so its result is never used.
    with pytest.raises(ValueError):
        llm.generate(example_prompts, guided_params)
def test_parallel_sampling(vllm_model, example_prompts) -> None: def test_parallel_sampling(vllm_model, example_prompts) -> None:
"""Test passes if parallel sampling `n>1` yields `n` unique completions. """Test passes if parallel sampling `n>1` yields `n` unique completions.
......
...@@ -152,6 +152,11 @@ class Processor: ...@@ -152,6 +152,11 @@ class Processor:
if not params.guided_decoding or not self.decoding_config: if not params.guided_decoding or not self.decoding_config:
return return
if self.model_config.skip_tokenizer_init and params.guided_decoding:
raise ValueError(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
)
engine_level_backend = self.decoding_config.backend engine_level_backend = self.decoding_config.backend
if params.guided_decoding.backend: if params.guided_decoding.backend:
# Request-level backend selection is not supported in V1. # Request-level backend selection is not supported in V1.
......
...@@ -40,10 +40,12 @@ class StructuredOutputManager: ...@@ -40,10 +40,12 @@ class StructuredOutputManager:
self._grammar_bitmask: Optional[torch.Tensor] = None self._grammar_bitmask: Optional[torch.Tensor] = None
self._full_mask = torch.tensor(-1, dtype=torch.int32) self._full_mask = torch.tensor(-1, dtype=torch.int32)
# The default max_workers if not specified is the number of CPUs * 5, if not self.vllm_config.model_config.skip_tokenizer_init:
# which is way too high since these tasks are CPU-bound, not I/O bound. # The default max_workers if not specified is the number of
# We also know we would never dominate CPU usage with just grammar # CPUs * 5, which is way too high since these tasks are CPU-bound,
# compilation, so we set it to half the number of CPUs. # not I/O bound. We also know we would never dominate CPU usage
# with just grammar compilation, so we set it to half the number
# of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers) self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
...@@ -51,7 +53,8 @@ class StructuredOutputManager: ...@@ -51,7 +53,8 @@ class StructuredOutputManager:
scheduler_config=self.vllm_config.scheduler_config, scheduler_config=self.vllm_config.scheduler_config,
lora_config=self.vllm_config.lora_config, lora_config=self.vllm_config.lora_config,
).get_lora_tokenizer(None) ).get_lora_tokenizer(None)
reasoning_backend = vllm_config.decoding_config.reasoning_backend reasoning_backend = \
self.vllm_config.decoding_config.reasoning_backend
if reasoning_backend: if reasoning_backend:
reasoner_cls = ReasoningParserManager.get_reasoning_parser( reasoner_cls = ReasoningParserManager.get_reasoning_parser(
reasoning_backend) reasoning_backend)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment