Commit d2b52805 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
...@@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, ...@@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
requests = [ requests = [
EngineCoreRequest(request_id=f"request-{idx}", EngineCoreRequest(request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, mm_features=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
...@@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ...@@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
requests = [ requests = [
EngineCoreRequest(request_id=request_id_list[idx], EngineCoreRequest(request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, mm_features=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
...@@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool,
request = EngineCoreRequest( request = EngineCoreRequest(
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, mm_features=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
...@@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool, ...@@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool,
EngineCoreRequest( EngineCoreRequest(
request_id=request_id_list[idx], request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, mm_features=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
...@@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors):
EngineCoreRequest( EngineCoreRequest(
request_id=f"request-{idx}", request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, mm_features=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.platforms.interface import UnspecifiedPlatform
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import processor as processor_mod
from vllm.v1.engine.processor import Processor
# Shared multimodal inputs used by the tests in this module. They are
# module-level statements, so the assets are resolved once at import time.
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
stop_pil_image = ImageAsset("stop_sign").pil_image
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
# Factory for a Processor whose configuration hooks are stubbed out.
def _mk_processor(monkeypatch,
                  *,
                  mm_cache_gb: float = 4.0,
                  enable_prefix_caching: bool = True) -> Processor:
    """
    Build a Processor for unit tests.

    Several ModelConfig / VllmConfig / platform hooks and the processor-cache
    factory are swapped for no-op stand-ins via ``monkeypatch`` before any
    config object is constructed; the intent is to keep the tests from
    touching external resources.

    Args:
        monkeypatch: pytest's monkeypatch fixture, used to install the stubs.
        mm_cache_gb: value used for ``mm_processor_cache_gb`` on both the
            ModelConfig and the stand-in multimodal config.
        enable_prefix_caching: forwarded to ``CacheConfig``.

    Returns:
        A Processor constructed with ``tokenizer=None``.
    """
    # (owner, attribute, replacement) triples, applied in this order.
    stubs = (
        (ModelConfig, "try_get_generation_config", lambda self: {}),
        (ModelConfig, "__post_init__", lambda self: None),
        (UnspecifiedPlatform, "is_async_output_supported",
         classmethod(lambda cls, enforce_eager: True)),
        (ModelConfig, "verify_async_output_proc",
         lambda self, parallel_config, speculative_config, device_config: None
         ),
        (ModelConfig, "verify_with_parallel_config",
         lambda self, parallel_config: None),
        (processor_mod, "processor_cache_from_config",
         lambda vllm_config, mm_registry: None),
        (VllmConfig, "__post_init__", lambda self: None),
    )
    for owner, attr_name, replacement in stubs:
        monkeypatch.setattr(owner, attr_name, replacement, raising=True)

    model_config = ModelConfig(
        skip_tokenizer_init=True,
        max_model_len=128,
        mm_processor_cache_gb=mm_cache_gb,
        generation_config="vllm",
        tokenizer="dummy",
    )

    # Processor.process_inputs reads model_config.multimodal_config, so give
    # it the smallest object that carries the cache size.
    class _StubMMConfig:

        def __init__(self, cache_gb: float):
            self.mm_processor_cache_gb = cache_gb

    model_config.multimodal_config = _StubMMConfig(
        mm_cache_gb)  # type: ignore[attr-defined]

    cache_config = CacheConfig(enable_prefix_caching=enable_prefix_caching)
    device_config = DeviceConfig(device="cpu")
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        device_config=device_config,
    )

    # tokenizer=None is passed on purpose; the original author notes that
    # InputPreprocessor copes with None when skip_tokenizer_init is True.
    return Processor(vllm_config, tokenizer=None)  # type: ignore[arg-type]
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
    """process_inputs must reject two images paired with a single UUID."""
    images = [cherry_pil_image, stop_pil_image]
    # Deliberately one entry short of the number of images.
    image_uuids = ["hash_cherry"]
    prompt = {
        "prompt": "USER: <image>\nDescribe\nASSISTANT:",
        "multi_modal_data": {
            "image": images
        },
        "multi_modal_uuids": {
            "image": image_uuids
        },
    }

    processor = _mk_processor(monkeypatch)

    with pytest.raises(ValueError, match="must have same length as data"):
        processor.process_inputs(
            request_id="req-1",
            prompt=prompt,  # type: ignore[arg-type]
            params=SamplingParams(),
        )
def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
    """UUIDs given for images but not for video must raise a ValueError."""
    # Data covers two modalities ...
    mm_data = {
        "image": [cherry_pil_image],
        "video": [baby_reading_np_ndarrays]
    }
    # ... while UUIDs cover only one of them.
    mm_uuids = {"image": ["hash_cherry"]}
    prompt = {
        "prompt": "USER: <image><video>\nDescribe\nASSISTANT:",
        "multi_modal_data": mm_data,
        "multi_modal_uuids": mm_uuids,
    }

    processor = _mk_processor(monkeypatch)

    with pytest.raises(ValueError,
                       match="must be provided if multi_modal_data"):
        processor.process_inputs(
            request_id="req-2",
            prompt=prompt,  # type: ignore[arg-type]
            params=SamplingParams(),
        )
@pytest.mark.parametrize(
    "mm_cache_gb, enable_prefix_caching",
    [
        (4.0, True),  # default behavior
        (4.0, False),  # prefix caching disabled
        (0.0, True),  # processor cache disabled
    ],
)
def test_multi_modal_uuids_accepts_none_and_passes_through(
        monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool):
    """User UUIDs (including None entries) reach preprocess unchanged."""
    processor = _mk_processor(monkeypatch,
                              mm_cache_gb=mm_cache_gb,
                              enable_prefix_caching=enable_prefix_caching)

    # Records what process_inputs hands to InputPreprocessor.preprocess.
    seen: dict[str, object] = {}

    def _recording_preprocess(prompt,
                              *,
                              tokenization_kwargs=None,
                              lora_request=None,
                              mm_hash_overrides=None):
        seen["mm_hash_overrides"] = mm_hash_overrides
        # Smallest processed-inputs payload for the decoder-only path.
        return {"type": "token", "prompt_token_ids": [1]}

    # Replace preprocess on this one input_preprocessor instance only.
    monkeypatch.setattr(processor.input_preprocessor,
                        "preprocess",
                        _recording_preprocess,
                        raising=True)

    # The same two-image scenario is used for every parametrization.
    mm_uuids = {"image": [None, "hash_stop"], "video": None}
    mm_data = {
        "image": [cherry_pil_image, stop_pil_image],
        "video": baby_reading_np_ndarrays,
    }
    prompt = {
        "prompt": "USER: <image><image>\nTwo images\nASSISTANT:",
        "multi_modal_data": mm_data,
        "multi_modal_uuids": mm_uuids,
    }

    processor.process_inputs(
        request_id="req-3",
        prompt=prompt,  # type: ignore[arg-type]
        params=SamplingParams(),
    )

    assert seen["mm_hash_overrides"] == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
    """
    With the processor cache at 0 GB and prefix caching off, the overrides
    passed to preprocess are expected to be derived from the request id
    rather than taken from the user-supplied UUIDs.
    """
    processor = _mk_processor(monkeypatch,
                              mm_cache_gb=0.0,
                              enable_prefix_caching=False)

    seen: dict[str, object] = {}

    def _recording_preprocess(prompt,
                              *,
                              tokenization_kwargs=None,
                              lora_request=None,
                              mm_hash_overrides=None):
        seen["mm_hash_overrides"] = mm_hash_overrides
        return {"type": "token", "prompt_token_ids": [1]}

    monkeypatch.setattr(processor.input_preprocessor,
                        "preprocess",
                        _recording_preprocess,
                        raising=True)

    request_id = "req-42"
    mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"}
    mm_data = {
        "image": [cherry_pil_image, stop_pil_image],
        "video": baby_reading_np_ndarrays,
    }
    prompt = {
        "prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:",
        "multi_modal_data": mm_data,
        "multi_modal_uuids": mm_uuids,
    }

    processor.process_inputs(
        request_id=request_id,
        prompt=prompt,  # type: ignore[arg-type]
        params=SamplingParams(),
    )

    # The overrides should be built from the request id, one per item.
    expected_overrides = {
        "image": [f"{request_id}-image-0", f"{request_id}-image-1"],
        "video": [f"{request_id}-video-0"],
    }
    assert seen["mm_hash_overrides"] == expected_overrides
...@@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any ...@@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any
import jsonschema import jsonschema
import pytest import pytest
import regex as re import regex as re
import torch
from pydantic import BaseModel from pydantic import BaseModel
from tests.reasoning.utils import run_reasoning_extraction from tests.reasoning.utils import run_reasoning_extraction
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -39,8 +41,11 @@ EAGLE_SPEC_CONFIG = { ...@@ -39,8 +41,11 @@ EAGLE_SPEC_CONFIG = {
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto",
None),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
...@@ -127,13 +132,15 @@ def test_structured_output( ...@@ -127,13 +132,15 @@ def test_structured_output(
temperature=1.0, temperature=1.0,
max_tokens=4096, max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=sample_json_schema)) guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
(f"Give an example JSON for an employee profile that fits this " prompt = ("Give an example JSON for an employee profile that fits this "
f"schema. Make the response as short as possible. Schema: " "schema. Make the response as short as possible. Schema: "
f"{sample_json_schema}") f"{sample_json_schema}")
] * 2, outputs = llm.generate(
sampling_params=sampling_params, [prompt] * 2,
use_tqdm=True) sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None assert outputs is not None
...@@ -144,7 +151,8 @@ def test_structured_output( ...@@ -144,7 +151,8 @@ def test_structured_output(
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
assert generated_text is not None assert generated_text is not None
assert "\n" not in generated_text if guided_decoding_backend != 'lm-format-enforcer':
assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text) output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema) jsonschema.validate(instance=output_json, schema=sample_json_schema)
...@@ -191,20 +199,24 @@ def test_structured_output( ...@@ -191,20 +199,24 @@ def test_structured_output(
with pytest.raises(ValueError, with pytest.raises(ValueError,
match="The provided JSON schema contains features " match="The provided JSON schema contains features "
"not supported by xgrammar."): "not supported by xgrammar."):
prompt = (f"Give an example JSON for an employee profile that "
f"fits this schema: {unsupported_json_schema}. "
f"Make the response as short as possible.")
llm.generate( llm.generate(
prompts=[(f"Give an example JSON for an employee profile that " [prompt] * 2,
f"fits this schema: {unsupported_json_schema}. "
f"Make the response as short as possible.")] * 2,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True,
)
else: else:
outputs = llm.generate(prompts=( prompt = (f"Give an example JSON object for a grade that "
"Give an example JSON object for a grade " f"fits this schema: {unsupported_json_schema}. "
"that fits this schema: " f"Make the response as short as possible.")
f"{unsupported_json_schema}. Make the response as short as " outputs = llm.generate(
"possible."), prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True,
)
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
assert output is not None assert output is not None
...@@ -217,7 +229,7 @@ def test_structured_output( ...@@ -217,7 +229,7 @@ def test_structured_output(
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) assert isinstance(parsed_json, dict)
if guided_decoding_backend != "outlines": if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
# #
# Test 4: Generate SQL statement using EBNF grammar # Test 4: Generate SQL statement using EBNF grammar
# #
...@@ -227,10 +239,9 @@ def test_structured_output( ...@@ -227,10 +239,9 @@ def test_structured_output(
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
outputs = llm.generate( outputs = llm.generate(
prompts=( ("Generate a sql statement that selects col_1 from "
"Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1. Make the response as short as "
"table_1 where it is equal to 1. Make the response as short as " "possible."),
"possible."),
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True, use_tqdm=True,
) )
...@@ -261,10 +272,9 @@ def test_structured_output( ...@@ -261,10 +272,9 @@ def test_structured_output(
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
outputs = llm.generate( outputs = llm.generate(
prompts=( ("Generate a sql statement that selects col_1 from "
"Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1. Make the response as short as "
"table_1 where it is equal to 1. Make the response as short as " "possible."),
"possible."),
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True, use_tqdm=True,
) )
...@@ -301,7 +311,6 @@ def test_structured_output( ...@@ -301,7 +311,6 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(grammar="not a grammar")) guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "): with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate( llm.generate(
prompts=
("Generate a sql statement that selects col_1 from " ("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short " "table_1 where it is equal to 1. Make the response as short "
"as possible."), "as possible."),
...@@ -316,11 +325,11 @@ def test_structured_output( ...@@ -316,11 +325,11 @@ def test_structured_output(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex)) guided_decoding=GuidedDecodingParams(regex=sample_regex))
prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. "
f"Make the response as short as possible.")
outputs = llm.generate( outputs = llm.generate(
prompts=[ [prompt] * 2,
(f"Give an example IPv4 address with this regex: {sample_regex}. "
f"Make the response as short as possible.")
] * 2,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True, use_tqdm=True,
) )
...@@ -343,11 +352,13 @@ def test_structured_output( ...@@ -343,11 +352,13 @@ def test_structured_output(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
outputs = llm.generate( outputs = llm.generate(
prompts=("The best language for type-safe systems programming is " ("The best language for type-safe systems programming is "
"(Make the response as short as possible.) "), "(Make the response as short as possible.) "),
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True,
)
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
assert output is not None assert output is not None
...@@ -367,12 +378,14 @@ def test_structured_output( ...@@ -367,12 +378,14 @@ def test_structured_output(
temperature=1.0, temperature=1.0,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=json_schema)) guided_decoding=GuidedDecodingParams(json=json_schema))
outputs = llm.generate(prompts=(
"Generate a JSON with the brand, model and car_type of the most " outputs = llm.generate(
"iconic car from the 90's. Make the response as short as " ("Generate a JSON with the brand, model and car_type of the most "
"possible."), "iconic car from the 90's. Make the response as short as "
sampling_params=sampling_params, "possible."),
use_tqdm=True) sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None assert outputs is not None
...@@ -411,10 +424,11 @@ def test_structured_output( ...@@ -411,10 +424,11 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(json=json_schema)) guided_decoding=GuidedDecodingParams(json=json_schema))
outputs = llm.generate( outputs = llm.generate(
prompts=("Generate a description of a frog using 50 characters. " ("Generate a description of a frog using 50 characters. "
"Make the response as short as possible."), "Make the response as short as possible."),
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True,
)
assert outputs is not None assert outputs is not None
...@@ -429,7 +443,7 @@ def test_structured_output( ...@@ -429,7 +443,7 @@ def test_structured_output(
output_json = json.loads(generated_text) output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema) jsonschema.validate(instance=output_json, schema=json_schema)
if guided_decoding_backend != "outlines": if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
# #
# Test 11: Generate structured output using structural_tag format # Test 11: Generate structured output using structural_tag format
# #
...@@ -498,7 +512,7 @@ Make the response as short as possible. ...@@ -498,7 +512,7 @@ Make the response as short as possible.
""" """
# Change this once other backends support structural_tag # Change this once other backends support structural_tag
outputs = llm.generate(prompts=prompt, outputs = llm.generate(prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True)
assert outputs is not None assert outputs is not None
...@@ -639,15 +653,13 @@ def test_structured_output_auto_mode( ...@@ -639,15 +653,13 @@ def test_structured_output_auto_mode(
f"{unsupported_json_schema}. Make the response as short as possible.") f"{unsupported_json_schema}. Make the response as short as possible.")
# This would fail with the default of "xgrammar", but in "auto" # This would fail with the default of "xgrammar", but in "auto"
# we will handle fallback automatically. # we will handle fallback automatically.
outputs = llm.generate(prompts=prompts, outputs = llm.generate(prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True)
# Make sure `auto` backend handling doesn't mess up sampling_params # Make sure `auto` backend handling doesn't mess up sampling_params
# and that we can reuse it without error. # and that we can reuse it without error.
outputs.extend( outputs.extend(
llm.generate(prompts=prompts, llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True))
sampling_params=sampling_params,
use_tqdm=True))
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
...@@ -705,7 +717,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): ...@@ -705,7 +717,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
max_tokens=256, max_tokens=256,
guided_decoding=guided_params) guided_decoding=guided_params)
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) outputs = llm.generate(prompt, sampling_params=sampling_params)
assert outputs is not None assert outputs is not None
generated_text = outputs[0].outputs[0].text generated_text = outputs[0].outputs[0].text
assert generated_text is not None assert generated_text is not None
...@@ -721,3 +733,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): ...@@ -721,3 +733,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
assert "a4" not in generated assert "a4" not in generated
assert "a5" not in generated assert "a5" not in generated
assert "a6" not in generated assert "a6" not in generated
@pytest.mark.parametrize("guided_decoding_backend",
                         ["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_guided_requests(
    monkeypatch: pytest.MonkeyPatch,
    sample_json_schema: dict[str, Any],
    guided_decoding_backend: str,
):
    """
    Run one guided (JSON-schema) request and one unguided request in the
    same batch, then check that only the guided one yields schema-valid JSON.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")

    # Eager execution is skipped on TPUs because the test wants to catch
    # recompilation at runtime there.
    running_on_tpu = current_platform.is_tpu()
    no_whitespace = guided_decoding_backend in {"xgrammar", "guidance"}

    llm = LLM(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        enforce_eager=not running_on_tpu,
        max_model_len=1024,
        guided_decoding_backend=guided_decoding_backend,
        guided_decoding_disable_any_whitespace=no_whitespace,
    )

    guided_prompt = (
        "Give an example JSON for an employee profile that fits this "
        "schema. Make the response as short as possible. Schema: "
        f"{sample_json_schema}")
    non_guided_prompt = "The diameter of the Earth in kilometers is "

    guided_params = SamplingParams(
        temperature=1.0,
        max_tokens=400,
        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
    # No max tokens, temp=0 so the content can be asserted on.
    non_guided_params = SamplingParams(
        seed=42,
        temperature=0,
        top_p=1.0,
    )

    outputs = llm.generate(prompts=[guided_prompt, non_guided_prompt],
                           sampling_params=[guided_params, non_guided_params],
                           use_tqdm=True)

    assert outputs is not None

    # Release the engine before asserting: a failed assertion would
    # short-circuit the test and leave the memory allocated.
    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    for index, output in enumerate(outputs):
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")

        if index != 0:
            # Unguided request: the exact text can't be pinned, but it
            # should be factual ...
            assert "12,742" in generated_text

            # ... and it should not parse as JSON.
            with pytest.raises(ValueError):
                json.loads(generated_text)
            continue

        # Guided request: single-line, schema-valid JSON.
        assert "\n" not in generated_text
        jsonschema.validate(instance=json.loads(generated_text),
                            schema=sample_json_schema)
...@@ -73,3 +73,16 @@ async def test_chat_with_input_type(client: openai.AsyncOpenAI): ...@@ -73,3 +73,16 @@ async def test_chat_with_input_type(client: openai.AsyncOpenAI):
], ) ], )
print(response) print(response)
assert response.status == "completed" assert response.status == "completed"
@pytest.mark.asyncio
async def test_logprobs(client: openai.AsyncOpenAI):
    """Request top-5 logprobs and check they appear on the last content part."""
    requested_top = 5
    response = await client.responses.create(
        include=["message.output_text.logprobs"],
        input="What is 13 * 24?",
        top_logprobs=requested_top,
    )
    print(response)

    last_part = response.output[-1].content[-1]
    assert last_part.logprobs
    assert len(last_part.logprobs[0].top_logprobs) == requested_top
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from typing import Any, Callable, Optional, Union
import pytest
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
class Mock:
    """Empty placeholder type passed where an executor backend is expected."""
    pass
class CustomMultiprocExecutor(MultiprocExecutor):
    """MultiprocExecutor that leaves a ``.marker`` file behind on each RPC.

    The tests look for that file to confirm that the custom executor class
    was the one actually used by the engine.
    """

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False,
                       unique_reply_rank: Optional[int] = None) -> list[Any]:
        """Touch ``.marker`` in the current directory, then delegate.

        Every argument is handed on to ``MultiprocExecutor.collective_rpc``
        and its result is returned unchanged.
        """
        # Drop marker to show that this was run
        with open(".marker", "w"):
            ...
        # Forward non_block / unique_reply_rank as well; dropping them would
        # silently replace the caller's choice with the parent's defaults.
        return super().collective_rpc(method,
                                      timeout,
                                      args,
                                      kwargs,
                                      non_block=non_block,
                                      unique_reply_rank=unique_reply_rank)
# Second name for the same executor class; the async test refers to it both
# directly and through its dotted import path.
CustomMultiprocExecutorAsync = CustomMultiprocExecutor
# Model name passed to every EngineArgs / AsyncEngineArgs in this file.
MODEL = "Qwen/Qwen3-0.6B"
def test_custom_executor_type_checking():
    """A backend class that is not an executor must raise ValueError."""
    bad_backend_kwargs = dict(
        model=MODEL,
        gpu_memory_utilization=0.2,
        max_model_len=8192,
        distributed_executor_backend=Mock,
    )

    # Synchronous engine. The args are built inside the context so that a
    # ValueError raised during construction also counts.
    with pytest.raises(ValueError):
        sync_args = EngineArgs(**bad_backend_kwargs)
        LLMEngine.from_engine_args(sync_args)

    # Asynchronous engine.
    with pytest.raises(ValueError):
        async_args = AsyncEngineArgs(**bad_backend_kwargs)
        AsyncLLM.from_engine_args(async_args)
@pytest.mark.parametrize("distributed_executor_backend", [
    CustomMultiprocExecutor,
    "tests.v1.executor.test_executor.CustomMultiprocExecutor"
])
def test_custom_executor(distributed_executor_backend, tmp_path):
    """One engine step with the custom executor must create ``.marker``."""
    original_dir = os.path.abspath(".")
    # Run from tmp_path so the marker file lands in a fresh directory.
    os.chdir(tmp_path)
    try:
        marker = ".marker"
        assert not os.path.exists(marker)

        engine = LLMEngine.from_engine_args(
            EngineArgs(
                model=MODEL,
                gpu_memory_utilization=0.2,
                max_model_len=8192,
                distributed_executor_backend=distributed_executor_backend,
                enforce_eager=True,  # reduce test time
            ))

        engine.add_request("0", "foo", SamplingParams(max_tokens=1))
        engine.step()

        assert os.path.exists(marker)
    finally:
        # Always restore the working directory for subsequent tests.
        os.chdir(original_dir)
@pytest.mark.parametrize("distributed_executor_backend", [
    CustomMultiprocExecutorAsync,
    "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync"
])
def test_custom_executor_async(distributed_executor_backend, tmp_path):
    """Generating via AsyncLLM with the custom executor creates ``.marker``."""
    original_dir = os.path.abspath(".")
    # Run from tmp_path so the marker file lands in a fresh directory.
    os.chdir(tmp_path)
    try:
        marker = ".marker"
        assert not os.path.exists(marker)

        engine = AsyncLLM.from_engine_args(
            AsyncEngineArgs(
                model=MODEL,
                gpu_memory_utilization=0.2,
                max_model_len=8192,
                distributed_executor_backend=distributed_executor_backend,
                enforce_eager=True,  # reduce test time
            ))
        sampling_params = SamplingParams(max_tokens=1)

        async def _drain_stream():
            # Consume the whole output stream; the items are not inspected.
            stream = engine.generate(request_id="0",
                                     prompt="foo",
                                     sampling_params=sampling_params)
            async for _ in stream:
                pass

        asyncio.run(_drain_stream())

        assert os.path.exists(marker)
    finally:
        # Always restore the working directory for subsequent tests.
        os.chdir(original_dir)
...@@ -14,6 +14,7 @@ from unittest.mock import patch ...@@ -14,6 +14,7 @@ from unittest.mock import patch
import pytest import pytest
import ray import ray
import torch
from vllm import LLM from vllm import LLM
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
...@@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( ...@@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker) NixlConnectorWorker)
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from .utils import create_request, create_scheduler, create_vllm_config from .utils import create_request, create_scheduler, create_vllm_config
...@@ -98,7 +100,6 @@ class FakeNixlWrapper: ...@@ -98,7 +100,6 @@ class FakeNixlWrapper:
def set_cycles_before_xfer_done(self, cycles: int): def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done.""" """Set the number of cycles before a transfer is considered done."""
self._cycles_before_xfer_done = cycles
@contextlib.contextmanager @contextlib.contextmanager
...@@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): ...@@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
sampling_params) sampling_params)
# Request-0 times out and is cleared! # Request-0 times out and is cleared!
assert '0' not in req_to_blocks assert '0' not in req_to_blocks
def test_register_kv_caches(dist_init):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
This test verifies:
1. nixl_wrapper.get_reg_descs() is called with caches_data containing
tensor metadata
2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing
block layout info
"""
vllm_config = create_vllm_config()
# Create test kv cache tensors using proper backend shape
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2,
block_size=16,
num_kv_heads=4,
head_size=64)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size(
) * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr()
]
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0)
# Get the mock instance
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
# Verify get_reg_descs was called with caches_data
assert mock_wrapper_instance.get_reg_descs.called
caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
assert len(caches_data) == 4
for i, cache_entry in enumerate(caches_data):
base_addr, size, _tp_rank, _ = cache_entry
assert size == expected_tensor_size, \
f"Entry {i}: Expected tensor size {expected_tensor_size}, " \
f"got {size}"
assert base_addr == expected_base_addrs[i], \
f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \
f"got {base_addr}"
# Verify get_xfer_descs was called with blocks_data
assert mock_wrapper_instance.get_xfer_descs.called
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]
# Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, \
f"Expected {expected_blocks_count} blocks, " \
f"got {len(blocks_data)}"
expected_block_len = expected_tensor_size // 2
for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
assert block_len == expected_block_len, \
f"Block entry {i}: Expected block len {expected_block_len}, " \
f"got {block_len}"
...@@ -162,9 +162,7 @@ def create_request(request_id: int, ...@@ -162,9 +162,7 @@ def create_request(request_id: int,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
multi_modal_kwargs=None, mm_features=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
block_hasher=get_request_block_hasher(block_size, hash_fn), block_hasher=get_request_block_hasher(block_size, hash_fn),
) )
...@@ -200,7 +198,6 @@ def create_model_runner_output( ...@@ -200,7 +198,6 @@ def create_model_runner_output(
req_ids=req_ids, req_ids=req_ids,
req_id_to_index=req_id_to_index, req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=None, pooler_output=None,
......
...@@ -8,10 +8,9 @@ from typing import Optional ...@@ -8,10 +8,9 @@ from typing import Optional
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate, from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
LogitsProcessor, LogitsProcessor)
MoveDirectionality) from vllm.v1.sample.logits_processor.builtin import process_dict_updates
MODEL_NAME = "facebook/opt-125m" MODEL_NAME = "facebook/opt-125m"
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5" POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
...@@ -45,37 +44,19 @@ class DummyLogitsProcessor(LogitsProcessor): ...@@ -45,37 +44,19 @@ class DummyLogitsProcessor(LogitsProcessor):
def __init__(self, vllm_config: "VllmConfig", device: torch.device, def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool): is_pin_memory: bool):
self.req_info: dict[int, SamplingParams] = {} self.req_info: dict[int, int] = {}
def is_argmax_invariant(self) -> bool: def is_argmax_invariant(self) -> bool:
"""Never impacts greedy sampling""" """Never impacts greedy sampling"""
return False return False
def update_state(self, batch_update: Optional[BatchUpdate]): def update_state(self, batch_update: Optional[BatchUpdate]):
if not batch_update: process_dict_updates(
return self.req_info,
batch_update,
# Process added requests. lambda params, _, __: params.extra_args and
for index, params, _, _ in batch_update.added: (params.extra_args.get("target_token")),
assert params is not None )
if params.extra_args and (target_token :=
params.extra_args.get("target_token")):
self.req_info[index] = target_token
if self.req_info:
# Process removed requests.
for index in batch_update.removed:
self.req_info.pop(index, None)
# Process moved requests, unidirectional move (a->b) and swap
# (a<->b)
for adx, bdx, direct in batch_update.moved:
a_val = self.req_info.pop(adx, None)
b_val = self.req_info.pop(bdx, None)
if a_val is not None:
self.req_info[bdx] = a_val
if direct == MoveDirectionality.SWAP and b_val is not None:
self.req_info[adx] = b_val
def apply(self, logits: torch.Tensor) -> torch.Tensor: def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.req_info: if not self.req_info:
......
...@@ -456,9 +456,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): ...@@ -456,9 +456,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
assert len(logprob) == vocab_size assert len(logprob) == vocab_size
@pytest.mark.parametrize( @pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
"logprobs_mode",
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
def test_logprobs_mode(logprobs_mode: LogprobsMode, def test_logprobs_mode(logprobs_mode: LogprobsMode,
monkeypatch: pytest.MonkeyPatch): monkeypatch: pytest.MonkeyPatch):
"""Test with LLM engine with different logprobs_mode. """Test with LLM engine with different logprobs_mode.
...@@ -487,12 +485,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode, ...@@ -487,12 +485,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
for logprobs in output.logprobs: for logprobs in output.logprobs:
for token_id in logprobs: for token_id in logprobs:
logprob = logprobs[token_id] logprob = logprobs[token_id]
if "logprobs" in logprobs_mode: if logprobs_mode in (LogprobsMode.RAW_LOGPROBS,
LogprobsMode.PROCESSED_LOGPROBS):
assert logprob.logprob <= 0 assert logprob.logprob <= 0
if logprob.logprob > 0: if logprob.logprob > 0:
positive_values = positive_values + 1 positive_values = positive_values + 1
total_token_with_logprobs = total_token_with_logprobs + 1 total_token_with_logprobs = total_token_with_logprobs + 1
assert total_token_with_logprobs >= len(results[0].outputs) assert total_token_with_logprobs >= len(results[0].outputs)
if "logits" in logprobs_mode: if logprobs_mode in (LogprobsMode.RAW_LOGITS,
LogprobsMode.PROCESSED_LOGITS):
assert positive_values > 0 assert positive_values > 0
del llm del llm
...@@ -50,6 +50,7 @@ def forward_attention( ...@@ -50,6 +50,7 @@ def forward_attention(
dtype=torch.int32, dtype=torch.int32,
) )
context_lens = seq_lens - query_lens context_lens = seq_lens - query_lens
max_seq_len = int(seq_lens.max())
max_query_len = q_len max_query_len = q_len
num_actual_tokens = query_start_loc[-1] num_actual_tokens = query_start_loc[-1]
...@@ -81,6 +82,7 @@ def forward_attention( ...@@ -81,6 +82,7 @@ def forward_attention(
num_reqs=batch_size, num_reqs=batch_size,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table, block_table_tensor=block_table,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
) )
......
...@@ -75,9 +75,10 @@ async def generate( ...@@ -75,9 +75,10 @@ async def generate(
], ],
) )
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"]) @pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
@pytest.mark.parametrize("async_scheduling", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind, async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str,
data_parallel_backend: str): async_scheduling: bool):
stats_loggers = {} stats_loggers = {}
...@@ -105,6 +106,7 @@ async def test_load(output_kind: RequestOutputKind, ...@@ -105,6 +106,7 @@ async def test_load(output_kind: RequestOutputKind,
prompt = "This is a test of data parallel" prompt = "This is a test of data parallel"
engine_args.data_parallel_backend = data_parallel_backend engine_args.data_parallel_backend = data_parallel_backend
engine_args.async_scheduling = async_scheduling
engine = AsyncLLM.from_engine_args(engine_args, engine = AsyncLLM.from_engine_args(engine_args,
stat_loggers=[SimpleStatsLogger]) stat_loggers=[SimpleStatsLogger])
after.callback(engine.shutdown) after.callback(engine.shutdown)
......
...@@ -11,7 +11,8 @@ import torch ...@@ -11,7 +11,8 @@ import torch
from vllm.multimodal.inputs import (MultiModalBatchedField, from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalFlatField, MultiModalFieldElem, MultiModalFlatField,
MultiModalKwargs, MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField, NestedTensors) MultiModalSharedField, NestedTensors)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
...@@ -96,7 +97,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch): ...@@ -96,7 +97,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
class MyRequest(msgspec.Struct): class MyRequest(msgspec.Struct):
mm: Optional[list[MultiModalKwargs]] mm: Optional[list[MultiModalKwargsItems]]
def test_multimodal_kwargs(): def test_multimodal_kwargs():
...@@ -119,7 +120,7 @@ def test_multimodal_kwargs(): ...@@ -119,7 +120,7 @@ def test_multimodal_kwargs():
audio = MultiModalKwargsItem.from_elems([e1]) audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2]) video = MultiModalKwargsItem.from_elems([e2])
image = MultiModalKwargsItem.from_elems([e3, e4]) image = MultiModalKwargsItem.from_elems([e3, e4])
mm = MultiModalKwargs([audio, video, image]) mm = MultiModalKwargsItems.from_seq([audio, video, image])
# pack mm kwargs into a mock request so that it can be decoded properly # pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest([mm]) req = MyRequest([mm])
...@@ -133,19 +134,22 @@ def test_multimodal_kwargs(): ...@@ -133,19 +134,22 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14255, +-20 for minor changes # expected total encoding length, should be 14306, +-20 for minor changes
assert 14250 <= total_len <= 14300 assert 14275 <= total_len <= 14325
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems)
# check all modalities were recovered and do some basic sanity checks # check all modalities were recovered and do some basic sanity checks
assert len(decoded.modalities) == 3 assert len(decoded) == 3
images = decoded.get_items("image") images = decoded["image"]
assert len(images) == 1 assert len(images) == 1
assert len(images[0].items()) == 2 assert len(images[0].items()) == 2
assert list(images[0].keys()) == ["i0", "i1"] assert list(images[0].keys()) == ["i0", "i1"]
# check the tensor contents and layout in the main dict # check the tensor contents and layout in the main dict
assert all(nested_equal(mm[k], decoded[k]) for k in mm) mm_data = mm.get_data()
decoded_data = decoded.get_data()
assert all(nested_equal(mm_data[k], decoded_data[k]) for k in mm_data)
def nested_equal(a: NestedTensors, b: NestedTensors): def nested_equal(a: NestedTensors, b: NestedTensors):
......
...@@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner): ...@@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner): ...@@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner): ...@@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner): ...@@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner): ...@@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
......
...@@ -205,6 +205,7 @@ def _construct_cached_request_state(req_id_suffix: int): ...@@ -205,6 +205,7 @@ def _construct_cached_request_state(req_id_suffix: int):
pooling_params=None, pooling_params=None,
mm_kwargs=[], mm_kwargs=[],
mm_positions=[], mm_positions=[],
mm_hashes=[],
block_ids=([], ), block_ids=([], ),
generator=None, generator=None,
num_computed_tokens=len(output_token_ids), num_computed_tokens=len(output_token_ids),
......
...@@ -141,7 +141,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -141,7 +141,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -207,7 +207,7 @@ def test_update_states_request_finished(model_runner, dist_init): ...@@ -207,7 +207,7 @@ def test_update_states_request_finished(model_runner, dist_init):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -239,7 +239,7 @@ def test_update_states_request_resumed(model_runner, dist_init): ...@@ -239,7 +239,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -266,7 +266,7 @@ def test_update_states_request_resumed(model_runner, dist_init): ...@@ -266,7 +266,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -347,7 +347,7 @@ def test_update_states_no_changes(model_runner, dist_init): ...@@ -347,7 +347,7 @@ def test_update_states_no_changes(model_runner, dist_init):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -384,7 +384,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init): ...@@ -384,7 +384,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
...@@ -680,6 +680,7 @@ def test_init_kv_cache_with_kv_sharing_valid(): ...@@ -680,6 +680,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
kv_cache_spec[layer_0].page_size_bytes kv_cache_spec[layer_0].page_size_bytes
runner.initialize_kv_cache(kv_cache_config) runner.initialize_kv_cache(kv_cache_config)
kv_cache_config_after_init = runner.kv_cache_config
layer_0_kv = vllm_ctx[layer_0].kv_cache[0] layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0] layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
...@@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid(): ...@@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert id(layer_1_kv) == id(layer_0_kv) assert id(layer_1_kv) == id(layer_0_kv)
# check layer 1 added to kv cache group's layer names # check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1 assert len(kv_cache_config_after_init.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 0] == layer_0
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
1] == layer_1
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
......
...@@ -26,9 +26,5 @@ compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing ...@@ -26,9 +26,5 @@ compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing
awq, casperhansen/mixtral-instruct-awq, main awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
qqq, HandH1998/QQQ-Llama-3-8b-g128, main
qqq, HandH1998/QQQ-Llama-3-8b, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
None, mgleize/fairseq2-dummy-Llama-3.2-1B, main None, mgleize/fairseq2-dummy-Llama-3.2-1B, main
\ No newline at end of file
...@@ -9,10 +9,7 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder ...@@ -9,10 +9,7 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.pooling_model_runner import (
ModelInputForGPUWithPoolingMetadata)
class MockAttentionBackend(AttentionBackend): class MockAttentionBackend(AttentionBackend):
...@@ -114,54 +111,3 @@ def test_model_runner_input(): ...@@ -114,54 +111,3 @@ def test_model_runner_input():
assert (received_model_input.sampling_metadata.selected_token_indices == assert (received_model_input.sampling_metadata.selected_token_indices ==
sampling_metadata.selected_token_indices) sampling_metadata.selected_token_indices)
assert received_model_input.sampling_metadata.seq_groups is None assert received_model_input.sampling_metadata.seq_groups is None
def test_embedding_model_runner_input():
pooling_metadata = PoolingMetadata(
seq_groups=[[0]],
seq_data={},
prompt_lens=[1],
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
pooling_metadata=pooling_metadata,
attn_metadata=attn_metadata)
assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (
ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
# Check that received copy has correct values.
assert isinstance(received_model_input,
ModelInputForGPUWithPoolingMetadata)
assert received_model_input.input_tokens is not None
assert (
received_model_input.input_tokens == model_input.input_tokens).all()
assert received_model_input.input_positions is not None
assert (received_model_input.input_positions == model_input.input_positions
).all()
assert received_model_input.multi_modal_kwargs is None
assert (received_model_input.multi_modal_kwargs ==
model_input.multi_modal_kwargs)
assert received_model_input.lora_requests is None
assert received_model_input.lora_requests == model_input.lora_requests
assert received_model_input.lora_mapping is None
assert received_model_input.lora_mapping == model_input.lora_mapping
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_model_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# Pooling metadata is not broadcast.
assert received_model_input.pooling_metadata is None
...@@ -37,7 +37,7 @@ ALLOWED_FILES = set([ ...@@ -37,7 +37,7 @@ ALLOWED_FILES = set([
'vllm/distributed/utils.py', 'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py', 'vllm/distributed/parallel_state.py',
'vllm/engine/multiprocessing/client.py', 'vllm/engine/multiprocessing/client.py',
'vllm/distributed/device_communicators/custom_all_reduce_utils.py', 'vllm/distributed/device_communicators/all_reduce_utils.py',
'vllm/distributed/device_communicators/shm_broadcast.py', 'vllm/distributed/device_communicators/shm_broadcast.py',
'vllm/engine/multiprocessing/engine.py', 'vllm/engine/multiprocessing/engine.py',
'benchmarks/kernels/graph_machete_bench.py', 'benchmarks/kernels/graph_machete_bench.py',
......
...@@ -77,6 +77,7 @@ clone_repo() { ...@@ -77,6 +77,7 @@ clone_repo() {
local repo_url=$1 local repo_url=$1
local dir_name=$2 local dir_name=$2
local key_file=$3 local key_file=$3
local commit_hash=$4
if [ -d "$dir_name" ]; then if [ -d "$dir_name" ]; then
# Check if directory has uncommitted changes (dirty) # Check if directory has uncommitted changes (dirty)
...@@ -87,17 +88,27 @@ clone_repo() { ...@@ -87,17 +88,27 @@ clone_repo() {
echo "$dir_name directory exists but clone appears incomplete, cleaning up and re-cloning" echo "$dir_name directory exists but clone appears incomplete, cleaning up and re-cloning"
rm -rf "$dir_name" rm -rf "$dir_name"
git clone "$repo_url" git clone "$repo_url"
if [ -n "$commit_hash" ]; then
cd "$dir_name"
git checkout "$commit_hash"
cd ..
fi
else else
echo "$dir_name directory exists and appears complete; manually update if needed" echo "$dir_name directory exists and appears complete; manually update if needed"
fi fi
else else
git clone "$repo_url" git clone "$repo_url"
if [ -n "$commit_hash" ]; then
cd "$dir_name"
git checkout "$commit_hash"
cd ..
fi
fi fi
} }
# build and install pplx, require pytorch installed # build and install pplx, require pytorch installed
pushd $WORKSPACE pushd $WORKSPACE
clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" "c336faf"
cd pplx-kernels cd pplx-kernels
# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925 # see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
# PIP_NO_BUILD_ISOLATION=0 disables build isolation # PIP_NO_BUILD_ISOLATION=0 disables build isolation
...@@ -106,7 +117,7 @@ popd ...@@ -106,7 +117,7 @@ popd
# build and install deepep, require pytorch installed # build and install deepep, require pytorch installed
pushd $WORKSPACE pushd $WORKSPACE
clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" "e3908bf"
cd DeepEP cd DeepEP
export NVSHMEM_DIR=$WORKSPACE/nvshmem_install export NVSHMEM_DIR=$WORKSPACE/nvshmem_install
PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e . PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e .
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment