"lib/vscode:/vscode.git/clone" did not exist on "d3b0cae1db6eed9abe546f985bb2d11cdaa622e7"
Commit 9c4ecf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-ori

parents bfc2d6f7 dc1b4a6f
...@@ -51,6 +51,10 @@ def run_test( ...@@ -51,6 +51,10 @@ def run_test(
model_info.check_available_online(on_fail="skip") model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip") model_info.check_transformers_version(on_fail="skip")
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
limit_mm_per_prompt = default_limits | limit_mm_per_prompt
vllm_outputs_per_mm = [] vllm_outputs_per_mm = []
hf_outputs_per_mm = [] hf_outputs_per_mm = []
......
...@@ -204,6 +204,12 @@ def idefics3_trunc_hf_output(hf_output: RunnerOutput, ...@@ -204,6 +204,12 @@ def idefics3_trunc_hf_output(hf_output: RunnerOutput,
return output_ids, output_str, out_logprobs return output_ids, output_str, out_logprobs
def smolvlm_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
# Based on Idefics3
return idefics3_trunc_hf_output(hf_output, model)
def minicpmv_trunc_hf_output(hf_output: RunnerOutput, def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput: model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output output_ids, output_str, out_logprobs = hf_output
......
...@@ -2,13 +2,16 @@ ...@@ -2,13 +2,16 @@
# ruff: noqa: E501 # ruff: noqa: E501
"""Compare the scoring outputs of HF and vLLM models. """Compare the scoring outputs of HF and vLLM models.
Run `pytest tests/models/embedding/language/test_jina_reranker_v2.py`. Run `pytest tests/models/embedding/language/test_jina.py`.
""" """
import math import math
import pytest import pytest
MODELS = [ from tests.models.embedding.utils import check_embeddings_close, matryoshka_fy
from vllm import PoolingParams
SCORING_MODELS = [
"jinaai/jina-reranker-v2-base-multilingual", # Roberta "jinaai/jina-reranker-v2-base-multilingual", # Roberta
] ]
...@@ -27,8 +30,21 @@ TEXTS_2 = [ ...@@ -27,8 +30,21 @@ TEXTS_2 = [
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています", "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
] ]
EMBEDDING_MODELS = [
"jinaai/jina-embeddings-v3",
]
EMBEDDING_PROMPTS = [
"Follow the white rabbit.", # English
"Sigue al conejo blanco.", # Spanish
"Suis le lapin blanc.", # French
"跟着白兔走。", # Chinese
"اتبع الأرنب الأبيض.", # Arabic
"Folge dem weißen Kaninchen.", # German
]
@pytest.fixture(scope="module", params=MODELS) @pytest.fixture(scope="module", params=SCORING_MODELS)
def model_name(request): def model_name(request):
yield request.param yield request.param
...@@ -68,3 +84,83 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): ...@@ -68,3 +84,83 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
@pytest.fixture(scope="module", params=EMBEDDING_MODELS)
def emb_model_name(request):
yield request.param
def test_is_matryoshka(vllm_runner, emb_model_name):
with vllm_runner(emb_model_name, task="embed",
max_model_len=None) as vllm_model:
assert vllm_model.model.llm_engine.model_config.is_matryoshka
@pytest.mark.parametrize("model", EMBEDDING_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_embeddings(
hf_runner,
vllm_runner,
model,
dtype: str,
monkeypatch,
) -> None:
example_prompts = EMBEDDING_PROMPTS
with hf_runner(
model,
dtype=dtype,
is_sentence_transformer=True,
) as hf_model:
hf_outputs = hf_model.encode(example_prompts, task="text-matching")
with vllm_runner(model, task="embed", dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
@pytest.mark.parametrize("model", EMBEDDING_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dimensions", [16, 32])
def test_matryoshka(
hf_runner,
vllm_runner,
model,
dtype: str,
dimensions: int,
monkeypatch,
) -> None:
example_prompts = EMBEDDING_PROMPTS
with hf_runner(
model,
dtype=dtype,
is_sentence_transformer=True,
) as hf_model:
hf_outputs = hf_model.encode(example_prompts, task="text-matching")
hf_outputs = matryoshka_fy(hf_outputs, dimensions)
with vllm_runner(model, task="embed", dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
...@@ -30,3 +30,10 @@ def check_embeddings_close( ...@@ -30,3 +30,10 @@ def check_embeddings_close(
f"\n{name_1}:\t{embeddings_1[:16]!r}") f"\n{name_1}:\t{embeddings_1[:16]!r}")
assert sim >= 1 - tol, fail_msg assert sim >= 1 - tol, fail_msg
def matryoshka_fy(tensor, dimensions):
tensor = torch.tensor(tensor)
tensor = tensor[..., :dimensions]
tensor = F.normalize(tensor, p=2, dim=1)
return tensor
...@@ -209,14 +209,15 @@ def _run_test( ...@@ -209,14 +209,15 @@ def _run_test(
# will hurt multiprocessing backend with fork method (the default method). # will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size # max_model_len should be greater than image_feature_size
with vllm_runner(model, with vllm_runner(
dtype=dtype, model,
max_model_len=4096, dtype=dtype,
max_num_seqs=3, max_model_len=19212, # 3 max size images
tensor_parallel_size=tensor_parallel_size, max_num_seqs=3,
distributed_executor_backend=distributed_executor_backend, tensor_parallel_size=tensor_parallel_size,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT distributed_executor_backend=distributed_executor_backend,
}) as vllm_model: limit_mm_per_prompt={"image":
_LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
vllm_outputs_per_image = [ vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts, vllm_model.generate_greedy_logprobs(prompts,
max_tokens, max_tokens,
...@@ -422,7 +423,7 @@ def test_bnb_regression( ...@@ -422,7 +423,7 @@ def test_bnb_regression(
llm = LLM( llm = LLM(
model=model, model=model,
dtype=dtype, dtype=dtype,
max_model_len=4096, max_model_len=8192,
max_num_seqs=2, max_num_seqs=2,
quantization="bitsandbytes", quantization="bitsandbytes",
) )
...@@ -475,7 +476,7 @@ def test_explicit_implicit_prompt( ...@@ -475,7 +476,7 @@ def test_explicit_implicit_prompt(
llm = LLM( llm = LLM(
model=model, model=model,
dtype=dtype, dtype=dtype,
max_model_len=4096, max_model_len=8192,
max_num_seqs=2, max_num_seqs=2,
tensor_parallel_size=1, tensor_parallel_size=1,
) )
...@@ -506,8 +507,8 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, ...@@ -506,8 +507,8 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
with global_force_attn_backend_context_manager(attn_backend), vllm_runner( with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
max_model_len=4096, max_model_len=8192,
max_num_seqs=2, max_num_seqs=4,
tensor_parallel_size=1, tensor_parallel_size=1,
limit_mm_per_prompt={"image": limit_mm_per_prompt={"image":
_LIMIT_IMAGE_PER_PROMPT}) as vllm_model: _LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
...@@ -552,6 +553,23 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, ...@@ -552,6 +553,23 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
num_logprobs, num_logprobs,
images=images) images=images)
# Mixed batch with text and images with different numbers of tiles
prompts = [
"<|begin_of_text|>Hello!",
"<|begin_of_text|>Some text before.<|image|>What is in the image?", # noqa: E501
"<|begin_of_text|>Some text before.<|image|>What is in the image?", # noqa: E501
]
images = [
None,
[stop_sign],
# smaller image must be 2nd for the repro
[stop_sign.resize((448, 448))],
]
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs,
images=images)
class DummyModel: class DummyModel:
image_token_id = MLLAMA_IMAGE_TOKEN_ID image_token_id = MLLAMA_IMAGE_TOKEN_ID
...@@ -674,3 +692,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None: ...@@ -674,3 +692,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
f"full_text_row_masked_out_mask[{idx}] must be " \ f"full_text_row_masked_out_mask[{idx}] must be " \
f"'{must_be_masked}' " f"'{must_be_masked}' "
idx += 1 idx += 1
@pytest.mark.core_model
@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [
([6404], [[4]], [6404]),
([0, 6404], [[4]], [6404]),
([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
])
def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
expected) -> None:
dummy = DummyModel()
num_tokens_per_tile = 1601
actual_encoder_seq_lens = MllamaForConditionalGeneration \
._get_and_validate_encoder_lens(
dummy,
encoder_seq_lens,
num_tiles,
num_tokens_per_tile,
)
assert actual_encoder_seq_lens == expected, \
f"Expected {expected} but got {actual_encoder_seq_lens}"
...@@ -257,6 +257,8 @@ def _test_processing_correctness_mistral( ...@@ -257,6 +257,8 @@ def _test_processing_correctness_mistral(
"h2oai/h2ovl-mississippi-800m", "h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-1B",
"HuggingFaceM4/Idefics3-8B-Llama3", "HuggingFaceM4/Idefics3-8B-Llama3",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"llava-hf/llava-1.5-7b-hf", "llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf", "llava-hf/LLaVA-NeXT-Video-7B-hf",
......
...@@ -71,29 +71,14 @@ def test_processor_override( ...@@ -71,29 +71,14 @@ def test_processor_override(
# image token offsets # image token offsets
img_locs = processed_inputs["mm_placeholders"].get("image", []) img_locs = processed_inputs["mm_placeholders"].get("image", [])
assert len(img_locs) == num_imgs assert len(img_locs) == num_imgs
assert [img_loc["offset"] for img_loc in img_locs] == \ assert [img_loc.offset for img_loc in img_locs] == \
[i for i, v in enumerate(prompt_token_ids) \ [i for i, v in enumerate(prompt_token_ids) \
if v == config.boi_token_index] if v == config.boi_token_index]
# patch sizes and masks # patch sizes and masks
assert prompt_token_ids.count(config.image_token_index) \
== sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"])
patch_token_id = vocab[hf_processor.img_patch_token]
num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id)
mm_counts = {"image": num_imgs}
assert num_patches / num_imgs <= \
processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"]
num_patches_per_chunk = processor.info.get_patch_per_chunk( num_patches_per_chunk = processor.info.get_patch_per_chunk(
config.vision_config) config.vision_config)
assert prompt_token_ids.count(config.image_token_index) \ assert prompt_token_ids.count(config.image_token_index) \
== mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
assert mm_kwargs["pixel_values"].shape[0] \ assert mm_kwargs["pixel_values"].shape[0] \
== mm_kwargs["patches_per_image"].sum() == mm_kwargs["patches_per_image"].sum()
for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"],
mm_kwargs["aspect_ratios"]):
assert embed_is_patch.shape[0] == \
len(tokenizer.encode(
hf_processor._prompt_split_image(
aspect_ratio, num_patches_per_chunk),
add_special_tokens=False))
...@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one( ...@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one(
first_placeholder = image_placeholders[0] first_placeholder = image_placeholders[0]
# NOTE: There is a BOS token # NOTE: There is a BOS token
assert first_placeholder["offset"] == 1 assert first_placeholder.offset == 1
assert first_placeholder["length"] == ( assert first_placeholder.length == (
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
except Exception as exc: except Exception as exc:
......
...@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one( ...@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one(
first_placeholder = image_placeholders[0] first_placeholder = image_placeholders[0]
assert first_placeholder["offset"] == 0 assert first_placeholder.offset == 0
assert first_placeholder["length"] == len( assert first_placeholder.length == len(
processed_inputs["prompt_token_ids"]) // num_imgs processed_inputs["prompt_token_ids"]) // num_imgs
except Exception as exc: except Exception as exc:
failed_size_excs.append((image_size, exc)) failed_size_excs.append((image_size, exc))
......
# SPDX-License-Identifier: Apache-2.0
"""Tests for mllama's multimodal preprocessing and profiling."""
import pytest
from transformers import MllamaConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiling import MultiModalProfiler
from ...utils import build_model_context
@pytest.mark.parametrize("model_id",
["meta-llama/Llama-3.2-11B-Vision-Instruct"])
@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
@pytest.mark.parametrize("max_num_seqs", [1, 2, 8])
def test_profiling(
model_id: str,
max_model_len: int,
max_num_seqs: int,
):
# regression test for https://github.com/vllm-project/vllm/issues/13929
from vllm.model_executor.models.mllama import calc_token_per_chunk
model_config_kwargs = {
"max_model_len": max_model_len,
}
ctx = build_model_context(
model_id,
model_config_kwargs=model_config_kwargs,
limit_mm_per_prompt={"image": 1},
)
mm_config = ctx.get_mm_config()
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
profiler = MultiModalProfiler(processor)
dummy_encoder_data = profiler.get_encoder_dummy_data(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
)
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
)
hf_config = ctx.get_hf_config(MllamaConfig)
image_size = hf_config.vision_config.image_size
encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids)
] * max_num_seqs
mm_kwargs = processor.apply(
prompt=dummy_mm_data.prompt_text,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")]
num_tokens_per_tile = calc_token_per_chunk(image_size)
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
# simulate mllama image-present prefill.
for actual_len, last_group_len in zip(actual_encoder_seq_lens,
encoder_seq_lens):
assert actual_len >= last_group_len
# SPDX-License-Identifier: Apache-2.0
"""Tests for smolvlm's multimodal preprocessing kwargs."""
import pytest
from transformers import SmolVLMConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["HuggingFaceTB/SmolVLM2-2.2B-Instruct"])
# yapf: disable
@pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"),
[
({"max_image_size": {"longest_edge": 384}}, 1377),
({"max_image_size": {"longest_edge": 768}}, 405),
])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int,
num_imgs: int,
kwargs_on_init: bool,
):
"""Ensure Idefics3MultiModalProcessor handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context(
model_id,
mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
limit_mm_per_prompt={"image": num_imgs},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs
# Build the image str / prompt based on the number of images we pass
placeholders = "<image>" if num_imgs == 1 else "\n".join(
f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
prompt = f"<|im_start|>User:{placeholders}\n<end_of_utterance>\nAssistant:" # noqa: E501
# Build mm_data
image_size = ctx.get_hf_config(SmolVLMConfig).vision_config.image_size
dummy_image_size = (image_size * 4, image_size * 4)
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
# Ensure the placeholders format are correct
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"])
assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[
"input_ids"][0]
# Ensure we have the right number of placeholders per num_crops size
image_token_id = ctx.get_hf_config().image_token_id
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
assert img_tok_count == expected_toks_per_img * num_imgs
...@@ -124,6 +124,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -124,6 +124,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"), "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
"ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b", "ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b",
trust_remote_code=True), trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501
trust_remote_code=True),
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
trust_remote_code=True), trust_remote_code=True),
"Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024", # noqa: E501 "Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024", # noqa: E501
...@@ -144,6 +146,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -144,6 +146,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it", "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it",
min_transformers_version="4.50"), min_transformers_version="4.50"),
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo(
"THUDM/GLM-4-32B-Chat-0414",
is_available_online=False,
min_transformers_version="4.52.dev0"
),
"GPT2LMHeadModel": _HfExamplesInfo("gpt2"), "GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
"GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"), "GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"),
...@@ -202,6 +209,16 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -202,6 +209,16 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct", "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct",
extras={"2.5": "Qwen/Qwen2.5-7B-Instruct"}), # noqa: E501 extras={"2.5": "Qwen/Qwen2.5-7B-Instruct"}), # noqa: E501
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-8B",
is_available_online=False,
min_transformers_version="4.51"
),
"Qwen3MoeForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-MoE-15B-A2B",
is_available_online=False,
min_transformers_version="4.51"
),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b", "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
is_available_online=False), is_available_online=False),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501 "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
...@@ -277,12 +294,16 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -277,12 +294,16 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501 extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible."), # noqa: E501
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501 extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
min_transformers_version="4.51"),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501 extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501
"mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501 "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501
...@@ -305,7 +326,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -305,7 +326,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501 extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
max_transformers_version="4.48", max_transformers_version="4.48",
transformers_version_reason="Use of private method which no longer exists.", # noqa: E501 transformers_version_reason="Incorrectly-detected `tensorflow` import.", # noqa: E501
extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501 extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
...@@ -314,6 +335,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -314,6 +335,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True, trust_remote_code=True,
max_transformers_version="4.48",
transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501 extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True), trust_remote_code=True),
...@@ -328,6 +351,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -328,6 +351,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501 min_transformers_version="4.49"), # noqa: E501
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True), trust_remote_code=True),
# [Encoder-decoder] # [Encoder-decoder]
...@@ -351,6 +375,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -351,6 +375,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B",
trust_remote_code=True,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
} }
_TRANSFORMERS_MODELS = { _TRANSFORMERS_MODELS = {
......
...@@ -7,6 +7,8 @@ from transformers import PretrainedConfig ...@@ -7,6 +7,8 @@ from transformers import PretrainedConfig
from vllm import LLM from vllm import LLM
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore from vllm.v1.engine.core import EngineCore as V1EngineCore
from .registry import HF_EXAMPLE_MODELS from .registry import HF_EXAMPLE_MODELS
...@@ -42,14 +44,21 @@ def test_can_initialize(model_arch): ...@@ -42,14 +44,21 @@ def test_can_initialize(model_arch):
self.cache_config.num_gpu_blocks = 0 self.cache_config.num_gpu_blocks = 0
self.cache_config.num_cpu_blocks = 0 self.cache_config.num_cpu_blocks = 0
def _initalize_kv_caches_v1(self, vllm_config): def _initialize_kv_caches_v1(self, vllm_config):
# gpu_blocks (> 0), cpu_blocks kv_cache_specs = self.model_executor.get_kv_cache_specs()
return 1, 0 scheduler_kv_cache_config = get_kv_cache_config(
vllm_config,
kv_cache_specs[0],
20 * GiB_bytes,
)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config
with (patch.object(V0LLMEngine, "_initialize_kv_caches", with (patch.object(V0LLMEngine, "_initialize_kv_caches",
_initialize_kv_caches_v0), _initialize_kv_caches_v0),
patch.object(V1EngineCore, "_initialize_kv_caches", patch.object(V1EngineCore, "_initialize_kv_caches",
_initalize_kv_caches_v1)): _initialize_kv_caches_v1)):
LLM( LLM(
model_info.default, model_info.default,
tokenizer=model_info.tokenizer, tokenizer=model_info.tokenizer,
......
...@@ -90,6 +90,7 @@ def test_oot_registration_multimodal( ...@@ -90,6 +90,7 @@ def test_oot_registration_multimodal(
max_model_len=4096, max_model_len=4096,
enforce_eager=True, enforce_eager=True,
limit_mm_per_prompt={"image": 1}) limit_mm_per_prompt={"image": 1})
first_token = llm.get_tokenizer().decode(0) first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
......
...@@ -255,6 +255,7 @@ def build_model_context( ...@@ -255,6 +255,7 @@ def build_model_context(
model_id: str, model_id: str,
task: TaskOption = "auto", task: TaskOption = "auto",
dtype: Union[str, torch.dtype] = "auto", dtype: Union[str, torch.dtype] = "auto",
model_config_kwargs: Optional[dict[str, Any]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None, limit_mm_per_prompt: Optional[dict[str, int]] = None,
disable_mm_preprocessor_cache: bool = True, disable_mm_preprocessor_cache: bool = True,
...@@ -274,6 +275,7 @@ def build_model_context( ...@@ -274,6 +275,7 @@ def build_model_context(
model_info.check_available_online(on_fail="skip") model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip") model_info.check_transformers_version(on_fail="skip")
model_config_kwargs = model_config_kwargs or {}
model_config = ModelConfig( model_config = ModelConfig(
model_id, model_id,
task=task, task=task,
...@@ -286,5 +288,6 @@ def build_model_context( ...@@ -286,5 +288,6 @@ def build_model_context(
limit_mm_per_prompt=limit_mm_per_prompt, limit_mm_per_prompt=limit_mm_per_prompt,
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
**model_config_kwargs,
) )
return InputContext(model_config) return InputContext(model_config)
...@@ -785,6 +785,7 @@ def test_find_update_tokens( ...@@ -785,6 +785,7 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=6, start_idx=6,
tokens=[32000, 32000], tokens=[32000, 32000],
is_embed=None,
), ),
], ],
"pattern_4": [ "pattern_4": [
...@@ -793,6 +794,7 @@ def test_find_update_tokens( ...@@ -793,6 +794,7 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=3, start_idx=3,
tokens=[32000], tokens=[32000],
is_embed=None,
), ),
], ],
} }
...@@ -807,12 +809,14 @@ def test_find_update_tokens( ...@@ -807,12 +809,14 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=1, start_idx=1,
tokens=[32000, 32000], tokens=[32000, 32000],
is_embed=None,
), ),
PlaceholderFeaturesInfo( PlaceholderFeaturesInfo(
modality="pattern_1", modality="pattern_1",
item_idx=1, item_idx=1,
start_idx=5, start_idx=5,
tokens=[32000, 32000], tokens=[32000, 32000],
is_embed=None,
), ),
], ],
"pattern_3": [ "pattern_3": [
...@@ -821,6 +825,7 @@ def test_find_update_tokens( ...@@ -821,6 +825,7 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=7, start_idx=7,
tokens=[1550, 918, 1550], tokens=[1550, 918, 1550],
is_embed=None,
), ),
], ],
# No match for pattern_4 as it has lower priority than pattern_1 # No match for pattern_4 as it has lower priority than pattern_1
...@@ -835,12 +840,14 @@ def test_find_update_tokens( ...@@ -835,12 +840,14 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=1, start_idx=1,
tokens=[32000, 32000], tokens=[32000, 32000],
is_embed=None,
), ),
PlaceholderFeaturesInfo( PlaceholderFeaturesInfo(
modality="pattern_1", modality="pattern_1",
item_idx=1, item_idx=1,
start_idx=3, start_idx=3,
tokens=[32000, 32000], tokens=[32000, 32000],
is_embed=None,
), ),
], ],
"pattern_4": [ "pattern_4": [
...@@ -849,6 +856,7 @@ def test_find_update_tokens( ...@@ -849,6 +856,7 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=5, start_idx=5,
tokens=[32000], tokens=[32000],
is_embed=None,
), ),
], ],
"pattern_3": [ "pattern_3": [
...@@ -857,6 +865,7 @@ def test_find_update_tokens( ...@@ -857,6 +865,7 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=6, start_idx=6,
tokens=[1550, 918, 1550], tokens=[1550, 918, 1550],
is_embed=None,
), ),
], ],
} }
...@@ -963,10 +972,13 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): ...@@ -963,10 +972,13 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
if is_valid: if is_valid:
exc_ctx = nullcontext() exc_ctx = nullcontext()
else: else:
exc_ctx = pytest.raises(ValueError, match="this model only supports") exc_ctx = pytest.raises(ValueError, match="The model only supports")
with exc_ctx: with exc_ctx:
profiler.get_decoder_dummy_data(model_config.max_model_len) profiler.get_decoder_dummy_data(
model_config.max_model_len,
mm_counts=limit_mm_per_prompt,
)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
......
...@@ -41,7 +41,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, ...@@ -41,7 +41,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
hf_model_kwargs = {"load_in_4bit": True} hf_model_kwargs = {"load_in_4bit": True}
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name, hf_model_kwargs) model_name, False, hf_model_kwargs)
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
...@@ -53,7 +53,7 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, ...@@ -53,7 +53,7 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None: model_name, description) -> None:
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name) model_name, True)
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
...@@ -65,7 +65,7 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, ...@@ -65,7 +65,7 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None: model_name, description) -> None:
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name) model_name, True)
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
...@@ -82,6 +82,7 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, ...@@ -82,6 +82,7 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
vllm_runner, vllm_runner,
example_prompts[:1], example_prompts[:1],
model_name, model_name,
False,
hf_model_kwargs, hf_model_kwargs,
vllm_tp_size=2) vllm_tp_size=2)
...@@ -128,13 +129,14 @@ def validate_generated_texts(hf_runner, ...@@ -128,13 +129,14 @@ def validate_generated_texts(hf_runner,
vllm_runner, vllm_runner,
prompts, prompts,
model_name, model_name,
pre_quant=False,
hf_model_kwargs=None, hf_model_kwargs=None,
vllm_tp_size=1): vllm_tp_size=1):
# NOTE: run vLLM first, as it requires a clean process # NOTE: run vLLM first, as it requires a clean process
# when using distributed inference # when using distributed inference
with vllm_runner(model_name, with vllm_runner(model_name,
quantization='bitsandbytes', quantization=None if pre_quant else 'bitsandbytes',
tensor_parallel_size=vllm_tp_size, tensor_parallel_size=vllm_tp_size,
enforce_eager=False) as llm: enforce_eager=False) as llm:
vllm_outputs = llm.generate_greedy(prompts, 8) vllm_outputs = llm.generate_greedy(prompts, 8)
......
...@@ -4,17 +4,28 @@ ...@@ -4,17 +4,28 @@
Run `pytest tests/quantization/test_quark.py`. Run `pytest tests/quantization/test_quark.py`.
""" """
import torch import pytest
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501 from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW8A8Fp8) QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.platforms import current_platform
def test_quark_fp8(vllm_runner, monkeypatch): @pytest.fixture(scope="function", autouse=True)
# vllm_runner.apply_model() relies on V0 internals. def use_v0_only(monkeypatch):
monkeypatch.setenv("VLLM_USE_V1", "0") """
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
@pytest.mark.parametrize('tp', [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test" model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
with vllm_runner(model_path) as llm: with vllm_runner(model_path,
kv_cache_dtype=kv_cache_dtype,
tensor_parallel_size=tp) as llm:
def check_model(model): def check_model(model):
layer = model.model.layers[0] layer = model.model.layers[0]
...@@ -26,11 +37,29 @@ def test_quark_fp8(vllm_runner, monkeypatch): ...@@ -26,11 +37,29 @@ def test_quark_fp8(vllm_runner, monkeypatch):
if isinstance(qkv_proj.scheme, QuarkW8A8Fp8): if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0 assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
assert len(qkv_proj.weight_scale.shape) == 0 assert len(qkv_proj.weight_scale.shape) == 0
llm.apply_model(check_model) llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20) output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output assert output
@pytest.mark.parametrize('tp', [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
with vllm_runner(model_path, tensor_parallel_size=tp) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert isinstance(qkv_proj.scheme, QuarkW8A8Int8)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
# SPDX-License-Identifier: Apache-2.0
import importlib.metadata
import importlib.util
import pytest
DTYPE = ["bfloat16"]
TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_pre_quantized_model(vllm_runner):
with vllm_runner("drisspg/float8_dynamic_act_float8_weight-opt-125m",
quantization="torchao",
dtype="bfloat16",
enforce_eager=True) as llm:
output = llm.generate_greedy(["The capital of France is"],
max_tokens=32)
assert output
print(output)
if __name__ == "__main__":
pytest.main([__file__])
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Tests for the SamplingParams class. """Tests for the SamplingParams class.
""" """
import pytest
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
MODEL_NAME = "Qwen/Qwen1.5-7B"
def test_max_tokens_none(): def test_max_tokens_none():
...@@ -9,6 +16,74 @@ def test_max_tokens_none(): ...@@ -9,6 +16,74 @@ def test_max_tokens_none():
SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
if __name__ == "__main__": @pytest.fixture(scope="module")
import pytest def model_config():
pytest.main([__file__]) return ModelConfig(
MODEL_NAME,
task="auto",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
)
@pytest.fixture(scope="module")
def default_max_tokens():
return 4096
def test_sampling_params_from_request_with_no_guided_decoding_backend(
model_config, default_max_tokens):
# guided_decoding_backend is not present at request level
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
MODEL_NAME,
'response_format': {
'type': 'json_object',
},
})
sampling_params = request.to_sampling_params(
default_max_tokens,
model_config.logits_processor_pattern,
)
# we do not expect any backend to be present and the default
# guided_decoding_backend at engine level will be used.
assert sampling_params.guided_decoding.backend is None
@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
[("xgrammar", "xgrammar"),
("lm-format-enforcer", "lm-format-enforcer"),
("outlines", "outlines")])
def test_sampling_params_from_request_with_guided_decoding_backend(
request_level_guided_decoding_backend: str, expected: str,
model_config, default_max_tokens):
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
MODEL_NAME,
'response_format': {
'type': 'json_object',
},
'guided_decoding_backend':
request_level_guided_decoding_backend,
})
sampling_params = request.to_sampling_params(
default_max_tokens,
model_config.logits_processor_pattern,
)
# backend correctly identified in resulting sampling_params
assert sampling_params.guided_decoding.backend == expected
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment