Commit 31330101 authored by zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-dev

parents e8933c34 dc1b4a6f
...@@ -15,3 +15,6 @@ torchaudio==2.6.0; platform_machine == "ppc64le" ...@@ -15,3 +15,6 @@ torchaudio==2.6.0; platform_machine == "ppc64le"
torchvision; platform_machine != "ppc64le" and platform_machine != "s390x" torchvision; platform_machine != "ppc64le" and platform_machine != "s390x"
torchvision==0.21.0; platform_machine == "ppc64le" torchvision==0.21.0; platform_machine == "ppc64le"
datasets # for benchmark scripts datasets # for benchmark scripts
# cpu cannot use triton 3.3.0
triton==3.2.0; platform_machine != "ppc64le"
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
-r common.txt -r common.txt
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
numba == 0.61; python_version > '3.9' numba == 0.61.2; python_version > '3.9'
# Dependencies for NVIDIA GPUs # Dependencies for NVIDIA GPUs
ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required for pipeline parallelism in V1. ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required for pipeline parallelism in V1.
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
ray ray
triton==3.1.0 triton==3.1.0
pandas pandas
numpy==1.26.4
tabulate tabulate
setuptools>=61 setuptools>=61
setuptools-scm>=8 setuptools-scm>=8
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
-r common.txt -r common.txt
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
numba == 0.61; python_version > '3.9' numba == 0.61.2; python_version > '3.9'
# Dependencies for hcus # Dependencies for hcus
awscli awscli
......
...@@ -5,6 +5,7 @@ pytest-forked ...@@ -5,6 +5,7 @@ pytest-forked
pytest-asyncio pytest-asyncio
pytest-rerunfailures pytest-rerunfailures
pytest-shard pytest-shard
pytest-timeout
# testing utils # testing utils
awscli awscli
...@@ -27,10 +28,11 @@ torchvision==0.21.0 ...@@ -27,10 +28,11 @@ torchvision==0.21.0
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.5.4 # required for pixtral test mistral_common[opencv] >= 1.5.4 # required for pixtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test lm-eval[api]==0.4.8 # required for model evaluation test
transformers==4.51.0 transformers==4.51.1
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
# quantization # quantization
bitsandbytes>=0.45.3 bitsandbytes>=0.45.3
...@@ -40,7 +42,7 @@ genai_perf==0.0.8 ...@@ -40,7 +42,7 @@ genai_perf==0.0.8
tritonclient==2.51.0 tritonclient==2.51.0
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
numba == 0.61; python_version > '3.9' numba == 0.61.2; python_version > '3.9'
numpy numpy
runai-model-streamer==0.11.0 runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0 runai-model-streamer-s3==0.11.0
......
...@@ -101,6 +101,8 @@ dill==0.3.8 ...@@ -101,6 +101,8 @@ dill==0.3.8
# multiprocess # multiprocess
dnspython==2.7.0 dnspython==2.7.0
# via email-validator # via email-validator
docopt==0.6.2
# via num2words
docutils==0.16 docutils==0.16
# via awscli # via awscli
einops==0.8.0 einops==0.8.0
...@@ -263,7 +265,9 @@ networkx==3.2.1 ...@@ -263,7 +265,9 @@ networkx==3.2.1
# via torch # via torch
nltk==3.9.1 nltk==3.9.1
# via rouge-score # via rouge-score
numba==0.61.0 num2words==0.5.14
# via -r requirements/test.in
numba==0.61.2
# via # via
# -r requirements/test.in # -r requirements/test.in
# librosa # librosa
...@@ -444,6 +448,7 @@ pytest==8.3.3 ...@@ -444,6 +448,7 @@ pytest==8.3.3
# pytest-mock # pytest-mock
# pytest-rerunfailures # pytest-rerunfailures
# pytest-shard # pytest-shard
# pytest-timeout
pytest-asyncio==0.24.0 pytest-asyncio==0.24.0
# via -r requirements/test.in # via -r requirements/test.in
pytest-forked==1.6.0 pytest-forked==1.6.0
...@@ -454,6 +459,8 @@ pytest-rerunfailures==14.0 ...@@ -454,6 +459,8 @@ pytest-rerunfailures==14.0
# via -r requirements/test.in # via -r requirements/test.in
pytest-shard==0.1.2 pytest-shard==0.1.2
# via -r requirements/test.in # via -r requirements/test.in
pytest-timeout==2.3.1
# via -r requirements/test.in
python-dateutil==2.9.0.post0 python-dateutil==2.9.0.post0
# via # via
# botocore # botocore
...@@ -645,7 +652,7 @@ tqdm==4.66.6 ...@@ -645,7 +652,7 @@ tqdm==4.66.6
# transformers # transformers
tqdm-multiprocess==0.0.11 tqdm-multiprocess==0.0.11
# via lm-eval # via lm-eval
transformers==4.51.0 transformers==4.51.1
# via # via
# -r requirements/test.in # -r requirements/test.in
# genai-perf # genai-perf
......
...@@ -17,10 +17,10 @@ ray[data] ...@@ -17,10 +17,10 @@ ray[data]
--find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250403-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250403-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250403-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250403-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250403-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250403-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
...@@ -563,9 +563,9 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -563,9 +563,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
new_version_content = f""" new_version_content = f"""
try: try:
__version__ = "0.8.3" __version__ = "0.8.4"
__version_tuple__ = (0, 8, 3) __version_tuple__ = (0, 8, 4)
__hcu_version__ = f'0.8.3+{version}' __hcu_version__ = f'0.8.4+{version}'
from vllm.version import __version__, __version_tuple__, __hcu_version__ from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e: except Exception as e:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Union from typing import Any, Optional, Union
import pytest import pytest
import torch import torch
...@@ -15,7 +15,7 @@ from vllm.platforms import current_platform ...@@ -15,7 +15,7 @@ from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test
def models_list(all: bool): def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}), ("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
...@@ -32,47 +32,50 @@ def models_list(all: bool): ...@@ -32,47 +32,50 @@ def models_list(all: bool):
("meta-llama/Llama-3.2-1B-Instruct", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}),
] ]
if not all: if all:
return TEST_MODELS if is_quant_method_supported("aqlm"):
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
"quantization": "aqlm"
}))
# TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
"quantization": "gguf"
}))
if is_quant_method_supported("gptq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
"quantization": "gptq"
}))
if is_quant_method_supported("gptq_marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
"quantization": "gptq_marlin"
}))
if is_quant_method_supported("aqlm"): if is_quant_method_supported("gptq_marlin_24"):
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", { TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
"quantization": "aqlm" "quantization": "gptq_marlin_24"
}))
# TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
"quantization": "gguf"
}))
if is_quant_method_supported("gptq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
"quantization": "gptq"
}))
if is_quant_method_supported("gptq_marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
"quantization": "gptq_marlin"
}))
if is_quant_method_supported("gptq_marlin_24"):
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
"quantization": "gptq_marlin_24"
}))
if is_quant_method_supported("marlin"):
TEST_MODELS.append(
("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
"quantization": "marlin"
})) }))
if not current_platform.is_rocm() and is_quant_method_supported("awq"): if is_quant_method_supported("marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { TEST_MODELS.append(
"quantization": "AWQ" ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
})) "quantization": "marlin"
}))
return TEST_MODELS if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))
if keywords is None:
return TEST_MODELS
# filter by keywords
pred = lambda model: any(keyword in model[0] for keyword in keywords)
return list(filter(pred, TEST_MODELS))
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -96,20 +99,30 @@ def test_full_graph( ...@@ -96,20 +99,30 @@ def test_full_graph(
run_model(optimization_level, model, model_kwargs) run_model(optimization_level, model, model_kwargs)
PassConfig = CompilationConfig.PassConfig
# TODO(luka) add other supported compilation config scenarios here # TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize( @pytest.mark.parametrize(
"compilation_config", "compilation_config, model_info",
# additional compile sizes
[ [
CompilationConfig(level=CompilationLevel.PIECEWISE, # additional compile sizes, only some of the models
compile_sizes=[1, 2]) (CompilationConfig(level=CompilationLevel.PIECEWISE,
compile_sizes=[1, 2]), model)
for model in models_list(all=False)
] + [
# RMSNorm + quant fusion, only 8-bit quant models
(CompilationConfig(level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True,
enable_noop=True)), model)
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
]) ])
# only test some of the models # only test some of the models
@pytest.mark.parametrize("model_info", models_list(all=False))
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_custom_compile_config( def test_custom_compile_config(
model_info: tuple[str, dict[str, Any]],
compilation_config: CompilationConfig, compilation_config: CompilationConfig,
model_info: tuple[str, dict[str, Any]],
): ):
model, model_kwargs = model_info model, model_kwargs = model_info
print(f"MODEL={model}") print(f"MODEL={model}")
......
...@@ -44,12 +44,17 @@ class TestModel(torch.nn.Module): ...@@ -44,12 +44,17 @@ class TestModel(torch.nn.Module):
resid = torch.sqrt(x) resid = torch.sqrt(x)
y = self.norm[0](x) y = self.norm[0](x)
x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0]) x2 = self.fp8_linear.apply(y,
self.w[0],
self.wscale[0],
input_scale=self.scale[0])
# make sure resid is used for replacement to work # make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid) y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1], x3 = self.fp8_linear.apply(y2,
self.scale[1]) self.w[1],
self.wscale[1],
input_scale=self.scale[1])
y3, resid = self.norm[2](x3, resid) # use resid here y3, resid = self.norm[2](x3, resid) # use resid here
return y3 return y3
......
...@@ -676,8 +676,9 @@ class HfRunner: ...@@ -676,8 +676,9 @@ class HfRunner:
return [(output_ids, output_str, output_logprobs) return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs] for output_ids, output_str, output_logprobs in outputs]
def encode(self, prompts: list[str]) -> list[list[torch.Tensor]]: def encode(self, prompts: list[str], *args,
return self.model.encode(prompts) **kwargs) -> list[list[torch.Tensor]]:
return self.model.encode(prompts, *args, **kwargs)
def predict(self, prompts: list[list[str]]) -> torch.Tensor: def predict(self, prompts: list[list[str]]) -> torch.Tensor:
return self.model.predict(prompts, convert_to_tensor=True) return self.model.predict(prompts, convert_to_tensor=True)
...@@ -964,19 +965,19 @@ class VllmRunner: ...@@ -964,19 +965,19 @@ class VllmRunner:
req_outputs = self.model.classify(prompts) req_outputs = self.model.classify(prompts)
return [req_output.outputs.probs for req_output in req_outputs] return [req_output.outputs.probs for req_output in req_outputs]
def encode( def encode(self,
self, prompts: list[str],
prompts: list[str], images: Optional[PromptImageInput] = None,
images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None,
videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None,
audios: Optional[PromptAudioInput] = None, *args,
) -> list[list[float]]: **kwargs) -> list[list[float]]:
inputs = self.get_inputs(prompts, inputs = self.get_inputs(prompts,
images=images, images=images,
videos=videos, videos=videos,
audios=audios) audios=audios)
req_outputs = self.model.embed(inputs) req_outputs = self.model.embed(inputs, *args, **kwargs)
return [req_output.outputs.embedding for req_output in req_outputs] return [req_output.outputs.embedding for req_output in req_outputs]
def score( def score(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from argparse import ArgumentTypeError from argparse import ArgumentError, ArgumentTypeError
import pytest import pytest
...@@ -142,3 +142,39 @@ def test_composite_arg_parser(arg, expected, option): ...@@ -142,3 +142,39 @@ def test_composite_arg_parser(arg, expected, option):
else: else:
args = parser.parse_args([f"--{option}", arg]) args = parser.parse_args([f"--{option}", arg])
assert getattr(args, option.replace("-", "_")) == expected assert getattr(args, option.replace("-", "_")) == expected
def test_human_readable_model_len():
    """Check parsing of human-readable values for ``--max-model-len``.

    The assertions below pin the following behavior of the CLI option:

    * omitted -> ``None``; a plain integer is taken as-is;
    * lowercase suffixes are decimal multipliers (``k`` = 1000,
      ``m`` = 1_000_000);
    * uppercase suffixes are binary multipliers (``K`` = 1024,
      ``M`` = 2**20);
    * a decimal mantissa is accepted with a lowercase suffix and the result
      is truncated to an int;
    * anything else is rejected with ``argparse.ArgumentError``.
    """
    # `exit_on_error` is disabled so that invalid values raise
    # `ArgumentError` (asserted at the end) instead of exiting the process.
    parser = EngineArgs.add_cli_args(
        FlexibleArgumentParser(exit_on_error=False))

    # Default when the flag is not given
    args = parser.parse_args([])
    assert args.max_model_len is None

    # Plain integer
    args = parser.parse_args(["--max-model-len", "1024"])
    assert args.max_model_len == 1024

    # Lowercase suffix -> decimal multiplier
    args = parser.parse_args(["--max-model-len", "1m"])
    assert args.max_model_len == 1_000_000
    args = parser.parse_args(["--max-model-len", "10k"])
    assert args.max_model_len == 10_000

    # Uppercase suffix -> binary multiplier
    args = parser.parse_args(["--max-model-len", "3K"])
    assert args.max_model_len == 1024 * 3
    args = parser.parse_args(["--max-model-len", "10M"])
    assert args.max_model_len == 2**20 * 10

    # Decimal mantissa with a decimal multiplier
    args = parser.parse_args(["--max-model-len", "10.2k"])
    assert args.max_model_len == 10200
    # ..truncated to the nearest int
    args = parser.parse_args(["--max-model-len", "10.212345k"])
    assert args.max_model_len == 10212

    # Invalid inputs: unknown suffix ("1a"), non-numeric text ("pwd"),
    # a decimal without any multiplier ("10.24"), and a decimal combined
    # with a binary multiplier ("1.23M").
    for invalid in ["1a", "pwd", "10.24", "1.23M"]:
        with pytest.raises(ArgumentError):
            # The parsed result is irrelevant here; only the raise matters.
            parser.parse_args(["--max-model-len", invalid])
...@@ -19,7 +19,8 @@ models = [os.path.join(models_path_prefix, "llava-hf/llava-1.5-7b-hf")] ...@@ -19,7 +19,8 @@ models = [os.path.join(models_path_prefix, "llava-hf/llava-1.5-7b-hf")]
def test_context_length_too_short(vllm_runner, image_assets, model): def test_context_length_too_short(vllm_runner, image_assets, model):
images = [asset.pil_image for asset in image_assets] images = [asset.pil_image for asset in image_assets]
with pytest.raises(ValueError, match="too long to fit into the model"): with pytest.raises(ValueError,
match="longer than the maximum model length"):
vllm_model = vllm_runner( vllm_model = vllm_runner(
model, model,
max_model_len=128, # LLaVA has a feature size of 576 max_model_len=128, # LLaVA has a feature size of 576
......
...@@ -3,10 +3,12 @@ ...@@ -3,10 +3,12 @@
import json import json
import re import re
import weakref import weakref
from enum import Enum
import jsonschema import jsonschema
import pytest import pytest
import os import os
from pydantic import BaseModel
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
...@@ -287,15 +289,26 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm): ...@@ -287,15 +289,26 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_disable_guided_decoding_fallback(sample_regex, llm): def test_disable_guided_decoding_fallback(sample_regex, llm):
# see has_xgrammar_unsupported_json_features()
unsupported_json = {
"type": "object",
"properties": {
"example": {
"type": "string",
"minLength": 5 # unsupported by xgrammar
}
}
}
sampling_params = SamplingParams(temperature=0.8, sampling_params = SamplingParams(temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(
regex=sample_regex, json=unsupported_json,
backend="xgrammar:no-fallback")) backend="xgrammar:no-fallback"))
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match="xgrammar does not support regex guided decoding"): match="xgrammar does not support advanced JSON schema features "
"like enums, patterns or numeric ranges."):
llm.generate(prompts="This should fail", llm.generate(prompts="This should fail",
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True)
...@@ -333,3 +346,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str): ...@@ -333,3 +346,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
# Parse to verify it is valid JSON # Parse to verify it is valid JSON
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) assert isinstance(parsed_json, dict)
# Closed set of car body styles used as the type of an enum-valued field in
# the guided-JSON test schema. Mixing in `str` makes every member a real
# string, so members compare equal to their plain-text values.
# NOTE: documented with comments rather than a docstring so that the JSON
# schema generated from models using this enum is not altered
# (presumably pydantic copies docstrings into `description` -- confirm).
class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"
# Pydantic model whose `model_json_schema()` output is used as the guided
# decoding schema in `test_guided_json_completion_with_enum`.
# NOTE: kept as `#` comments on purpose -- pydantic copies a model's
# docstring into the generated schema's `description`, which would change
# the schema handed to the decoding backends.
class CarDescription(BaseModel):
    brand: str
    model: str
    # Enum-typed field: the reason this schema exercises enum support.
    car_type: CarType
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
    """Guided JSON generation must honor a schema containing an enum field.

    The schema is derived from ``CarDescription`` (which has an enum-typed
    ``car_type`` field). Every generated completion must parse as JSON and
    validate against that schema.

    Args:
        llm: LLM fixture used to run generation.
        guided_decoding_backend: Name of the guided decoding backend under
            test, supplied by the parametrization.
    """
    json_schema = CarDescription.model_json_schema()
    sampling_params = SamplingParams(temperature=1.0,
                                     max_tokens=1000,
                                     guided_decoding=GuidedDecodingParams(
                                         json=json_schema,
                                         backend=guided_decoding_backend))
    # The two adjacent literals are concatenated by Python; the trailing
    # space after "of" keeps the prompt from reading "...car_type ofthe...".
    outputs = llm.generate(
        prompts="Generate a JSON with the brand, model and car_type of "
        "the most iconic car from the 90's",
        sampling_params=sampling_params,
        use_tqdm=True)

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        # Must be well-formed JSON *and* satisfy the schema (incl. the enum).
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=json_schema)
\ No newline at end of file
...@@ -16,8 +16,8 @@ def v1(run_with_both_engines): ...@@ -16,8 +16,8 @@ def v1(run_with_both_engines):
def test_empty_prompt():
    """An empty prompt must be rejected with a ``ValueError``.

    The error message is expected to match
    ``'decoder prompt cannot be empty'``.
    """
    # Single comma between arguments: the merged line had
    # `"openai-community/gpt2"),, enforce_eager=True)`, which is a
    # syntax error.
    llm = LLM(model=os.path.join(models_path_prefix, "openai-community/gpt2"),
              enforce_eager=True)
    with pytest.raises(ValueError, match='decoder prompt cannot be empty'):
        llm.generate([""])
......
...@@ -12,8 +12,10 @@ from ...utils import RemoteOpenAIServer, models_path_prefix ...@@ -12,8 +12,10 @@ from ...utils import RemoteOpenAIServer, models_path_prefix
MODEL_NAME = os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-llama-3_2-1b") MODEL_NAME = os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-llama-3_2-1b")
TEST_AUDIO_URLS = [ TEST_AUDIO_URLS = [
"http://localhost:8000/winning_call.ogg" AudioAsset("winning_call").url,
AudioAsset("mary_had_lamb").url,
] ]
MAXIMUM_AUDIOS = 2
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
...@@ -24,6 +26,8 @@ def server(): ...@@ -24,6 +26,8 @@ def server():
"5", "5",
"--enforce-eager", "--enforce-eager",
"--trust-remote-code", "--trust-remote-code",
"--limit-mm-per-prompt",
f"audio={MAXIMUM_AUDIOS}",
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
...@@ -46,7 +50,7 @@ def base64_encoded_audio() -> dict[str, str]: ...@@ -46,7 +50,7 @@ def base64_encoded_audio() -> dict[str, str]:
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_audio(client: openai.AsyncOpenAI, async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str): model_name: str, audio_url: str):
messages = [{ messages = [{
...@@ -100,7 +104,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, ...@@ -100,7 +104,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_audio_base64encoded( async def test_single_chat_session_audio_base64encoded(
client: openai.AsyncOpenAI, model_name: str, audio_url: str, client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: dict[str, str]): base64_encoded_audio: dict[str, str]):
...@@ -158,7 +162,7 @@ async def test_single_chat_session_audio_base64encoded( ...@@ -158,7 +162,7 @@ async def test_single_chat_session_audio_base64encoded(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_input_audio( async def test_single_chat_session_input_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str, client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: dict[str, str]): base64_encoded_audio: dict[str, str]):
...@@ -330,28 +334,21 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, ...@@ -330,28 +334,21 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) @pytest.mark.parametrize(
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]])
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
audio_url: str, audio_urls: list[str]):
base64_encoded_audio: dict[str, str]):
messages = [{ messages = [{
"role": "role":
"user", "user",
"content": [ "content": [
{ *({
"type": "audio_url", "type": "audio_url",
"audio_url": { "audio_url": {
"url": audio_url "url": audio_url
} }
}, } for audio_url in audio_urls),
{
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav"
}
},
{ {
"type": "text", "type": "text",
"text": "What's happening in this audio?" "text": "What's happening in this audio?"
...@@ -359,20 +356,30 @@ async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, ...@@ -359,20 +356,30 @@ async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
], ],
}] }]
with pytest.raises(openai.BadRequestError): # test multi-audio input if len(audio_urls) > MAXIMUM_AUDIOS:
await client.chat.completions.create( with pytest.raises(openai.BadRequestError): # test multi-audio input
await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0,
)
# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0
else:
chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
messages=messages, messages=messages,
max_completion_tokens=10, max_completion_tokens=10,
temperature=0.0, temperature=0.0,
) )
message = chat_completion.choices[0].message
# the server should still work afterwards assert message.content is not None and len(message.content) >= 0
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0
...@@ -21,8 +21,6 @@ from .test_completion import zephyr_lora_files # noqa: F401 ...@@ -21,8 +21,6 @@ from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = os.path.join(models_path_prefix, "HuggingFaceH4/zephyr-7b-beta") MODEL_NAME = os.path.join(models_path_prefix, "HuggingFaceH4/zephyr-7b-beta")
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def monkeypatch_module(): def monkeypatch_module():
...@@ -492,20 +490,9 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, ...@@ -492,20 +490,9 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
assert last_completion_tokens == 10 assert last_completion_tokens == 10
# NOTE: Not sure why, but when I place this after `test_guided_regex_chat`
# (i.e. using the same ordering as in the Completions API tests), the test
# will fail on the second `guided_decoding_backend` even when I swap their order
# (ref: https://github.com/vllm-project/vllm/pull/5526#issuecomment-2173772256)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_chat(client: openai.AsyncOpenAI, async def test_guided_choice_chat(client: openai.AsyncOpenAI,
is_v1_server: bool,
guided_decoding_backend: str,
sample_guided_choice): sample_guided_choice):
if is_v1_server and guided_decoding_backend != 'xgrammar':
pytest.skip("Only xgrammar backend is supported with V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -520,8 +507,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI, ...@@ -520,8 +507,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
messages=messages, messages=messages,
max_completion_tokens=10, max_completion_tokens=10,
temperature=0.7, temperature=0.7,
extra_body=dict(guided_choice=sample_guided_choice, extra_body=dict(guided_choice=sample_guided_choice))
guided_decoding_backend=guided_decoding_backend))
choice1 = chat_completion.choices[0].message.content choice1 = chat_completion.choices[0].message.content
assert choice1 in sample_guided_choice assert choice1 in sample_guided_choice
...@@ -535,22 +521,16 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI, ...@@ -535,22 +521,16 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
messages=messages, messages=messages,
max_completion_tokens=10, max_completion_tokens=10,
temperature=0.7, temperature=0.7,
extra_body=dict(guided_choice=sample_guided_choice, extra_body=dict(guided_choice=sample_guided_choice))
guided_decoding_backend=guided_decoding_backend))
choice2 = chat_completion.choices[0].message.content choice2 = chat_completion.choices[0].message.content
assert choice2 in sample_guided_choice assert choice2 in sample_guided_choice
assert choice1 != choice2 assert choice1 != choice2
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) async def test_guided_json_chat(client: openai.AsyncOpenAI,
async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool,
guided_decoding_backend: str,
sample_json_schema): sample_json_schema):
if is_v1_server:
pytest.skip("sample_json_schema has features unsupported in V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -565,8 +545,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool, ...@@ -565,8 +545,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool,
model=MODEL_NAME, model=MODEL_NAME,
messages=messages, messages=messages,
max_completion_tokens=1000, max_completion_tokens=1000,
extra_body=dict(guided_json=sample_json_schema, extra_body=dict(guided_json=sample_json_schema))
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
assert message.content is not None assert message.content is not None
json1 = json.loads(message.content) json1 = json.loads(message.content)
...@@ -583,8 +562,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool, ...@@ -583,8 +562,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool,
model=MODEL_NAME, model=MODEL_NAME,
messages=messages, messages=messages,
max_completion_tokens=1000, max_completion_tokens=1000,
extra_body=dict(guided_json=sample_json_schema, extra_body=dict(guided_json=sample_json_schema))
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
assert message.content is not None assert message.content is not None
json2 = json.loads(message.content) json2 = json.loads(message.content)
...@@ -594,13 +572,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool, ...@@ -594,13 +572,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex):
async def test_guided_regex_chat(client: openai.AsyncOpenAI,
is_v1_server: bool,
guided_decoding_backend: str, sample_regex):
if is_v1_server and guided_decoding_backend != 'xgrammar':
pytest.skip("Only xgrammar backend is supported with V1")
messages = [{ messages = [{
"role": "system", "role": "system",
...@@ -615,8 +587,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, ...@@ -615,8 +587,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI,
model=MODEL_NAME, model=MODEL_NAME,
messages=messages, messages=messages,
max_completion_tokens=20, max_completion_tokens=20,
extra_body=dict(guided_regex=sample_regex, extra_body=dict(guided_regex=sample_regex))
guided_decoding_backend=guided_decoding_backend))
ip1 = chat_completion.choices[0].message.content ip1 = chat_completion.choices[0].message.content
assert ip1 is not None assert ip1 is not None
assert re.fullmatch(sample_regex, ip1) is not None assert re.fullmatch(sample_regex, ip1) is not None
...@@ -627,8 +598,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, ...@@ -627,8 +598,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI,
model=MODEL_NAME, model=MODEL_NAME,
messages=messages, messages=messages,
max_completion_tokens=20, max_completion_tokens=20,
extra_body=dict(guided_regex=sample_regex, extra_body=dict(guided_regex=sample_regex))
guided_decoding_backend=guided_decoding_backend))
ip2 = chat_completion.choices[0].message.content ip2 = chat_completion.choices[0].message.content
assert ip2 is not None assert ip2 is not None
assert re.fullmatch(sample_regex, ip2) is not None assert re.fullmatch(sample_regex, ip2) is not None
...@@ -657,15 +627,9 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI): ...@@ -657,15 +627,9 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
is_v1_server: bool,
guided_decoding_backend: str,
sample_guided_choice): sample_guided_choice):
if is_v1_server and guided_decoding_backend != 'xgrammar':
pytest.skip("Only xgrammar backend is supported with V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -681,8 +645,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, ...@@ -681,8 +645,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
max_completion_tokens=10, max_completion_tokens=10,
logprobs=True, logprobs=True,
top_logprobs=5, top_logprobs=5,
extra_body=dict(guided_choice=sample_guided_choice, extra_body=dict(guided_choice=sample_guided_choice))
guided_decoding_backend=guided_decoding_backend))
assert chat_completion.choices[0].logprobs is not None assert chat_completion.choices[0].logprobs is not None
assert chat_completion.choices[0].logprobs.content is not None assert chat_completion.choices[0].logprobs.content is not None
...@@ -694,14 +657,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, ...@@ -694,14 +657,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema):
async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
guided_decoding_backend: str,
sample_json_schema):
if is_v1_server:
pytest.skip("sample_json_schema has features unsupported on V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -733,7 +689,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool, ...@@ -733,7 +689,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
"name": "dummy_function_name" "name": "dummy_function_name"
} }
}, },
extra_body=dict(guided_decoding_backend=guided_decoding_backend)) )
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
assert len(message.content) == 0 assert len(message.content) == 0
json_string = message.tool_calls[0].function.arguments json_string = message.tool_calls[0].function.arguments
...@@ -768,7 +724,6 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool, ...@@ -768,7 +724,6 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
"name": "dummy_function_name" "name": "dummy_function_name"
} }
}, },
extra_body=dict(guided_decoding_backend=guided_decoding_backend),
stream=True) stream=True)
output = [] output = []
...@@ -893,7 +848,6 @@ async def test_required_tool_use(client: openai.AsyncOpenAI, ...@@ -893,7 +848,6 @@ async def test_required_tool_use(client: openai.AsyncOpenAI,
model=model_name, model=model_name,
tools=tools, tools=tools,
tool_choice="required", tool_choice="required",
extra_body=dict(guided_decoding_backend="outlines"),
) )
assert chat_completion.choices[0].message.tool_calls is not None assert chat_completion.choices[0].message.tool_calls is not None
...@@ -905,7 +859,6 @@ async def test_required_tool_use(client: openai.AsyncOpenAI, ...@@ -905,7 +859,6 @@ async def test_required_tool_use(client: openai.AsyncOpenAI,
model=model_name, model=model_name,
tools=tools, tools=tools,
tool_choice="required", tool_choice="required",
extra_body=dict(guided_decoding_backend="outlines"),
stream=True, stream=True,
) )
...@@ -919,12 +872,7 @@ async def test_required_tool_use(client: openai.AsyncOpenAI, ...@@ -919,12 +872,7 @@ async def test_required_tool_use(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI, async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
is_v1_server: bool,
sample_json_schema): sample_json_schema):
if is_v1_server:
pytest.skip("sample_json_schema has features unsupported on V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
......
# SPDX-License-Identifier: Apache-2.0
import openai
import pytest
import pytest_asyncio
from vllm.config import ModelConfig
from ...utils import RemoteOpenAIServer
# Model launched by the module-scoped ``server`` fixture and named in every
# request issued by the tests in this file.
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
def get_vocab_size(model_name):
    """Return the vocabulary size reported by ``ModelConfig`` for *model_name*.

    The same name is passed as both the model and the tokenizer; the
    remaining settings are fixed so every call builds the same configuration.
    """
    config_kwargs = dict(
        model=model_name,
        task="auto",
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="bfloat16",
    )
    return ModelConfig(**config_kwargs).get_vocab_size()
@pytest.fixture(scope="module")
def server():
    """Yield a ``RemoteOpenAIServer`` for ``MODEL_NAME``, shared by the module.

    The server object is used as a context manager, so it is torn down once
    the last test of the module has finished with the fixture.
    """
    # Built option by option so each flag sits next to its value.
    cli_args = ["--dtype", "bfloat16"]
    cli_args += ["--max-model-len", "1024"]
    cli_args.append("--enforce-eager")

    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
@pytest_asyncio.fixture
async def client(server):
    """Yield an async client from *server*; its context exits on teardown."""
    async with server.get_async_client() as api_client:
        yield api_client
@pytest.mark.asyncio
async def test_chat_logit_bias_valid(client):
    """Check that a chat completion accepts an in-range logit_bias key."""
    # The highest token id that is still inside the vocabulary.
    last_token_id = get_vocab_size(MODEL_NAME) - 1
    bias = {str(last_token_id): 1.0}
    chat_messages = [{"role": "user", "content": "Testing valid logit bias"}]

    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=chat_messages,
        max_tokens=5,
        logit_bias=bias,
    )

    assert completion.choices[0].message.content is not None
@pytest.mark.asyncio
async def test_chat_logit_bias_invalid(client):
    """Check that a chat completion rejects an out-of-range logit_bias key."""
    vocab_size = get_vocab_size(MODEL_NAME)
    out_of_range_id = vocab_size + 1
    request_kwargs = dict(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Testing invalid logit bias"}],
        max_tokens=5,
        logit_bias={str(out_of_range_id): 1.0},
    )

    with pytest.raises(openai.BadRequestError) as excinfo:
        await client.chat.completions.create(**request_kwargs)

    rejected = excinfo.value
    detail = str(rejected)

    assert rejected.status_code == 400
    # The error text must name both the offending id and the vocabulary size.
    assert str(out_of_range_id) in detail
    assert str(vocab_size) in detail
...@@ -12,6 +12,7 @@ import requests ...@@ -12,6 +12,7 @@ import requests
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from ...models.embedding.utils import check_embeddings_close
from ...utils import RemoteOpenAIServer, models_path_prefix from ...utils import RemoteOpenAIServer, models_path_prefix
MODEL_NAME = os.path.join(models_path_prefix, "intfloat/multilingual-e5-small") MODEL_NAME = os.path.join(models_path_prefix, "intfloat/multilingual-e5-small")
...@@ -191,30 +192,35 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI, ...@@ -191,30 +192,35 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
responses_float = await client.embeddings.create(input=input_texts, responses_float = await client.embeddings.create(input=input_texts,
model=model_name, model=model_name,
encoding_format="float") encoding_format="float")
float_data = [d.embedding for d in responses_float.data]
responses_base64 = await client.embeddings.create(input=input_texts, responses_base64 = await client.embeddings.create(input=input_texts,
model=model_name, model=model_name,
encoding_format="base64") encoding_format="base64")
base64_data = []
decoded_responses_base64_data = []
for data in responses_base64.data: for data in responses_base64.data:
decoded_responses_base64_data.append( base64_data.append(
np.frombuffer(base64.b64decode(data.embedding), np.frombuffer(base64.b64decode(data.embedding),
dtype="float32").tolist()) dtype="float32").tolist())
assert responses_float.data[0].embedding == decoded_responses_base64_data[ check_embeddings_close(
0] embeddings_0_lst=float_data,
assert responses_float.data[1].embedding == decoded_responses_base64_data[ embeddings_1_lst=base64_data,
1] name_0="float",
name_1="base64",
)
# Default response is float32 decoded from base64 by OpenAI Client # Default response is float32 decoded from base64 by OpenAI Client
responses_default = await client.embeddings.create(input=input_texts, responses_default = await client.embeddings.create(input=input_texts,
model=model_name) model=model_name)
default_data = [d.embedding for d in responses_default.data]
assert responses_float.data[0].embedding == responses_default.data[ check_embeddings_close(
0].embedding embeddings_0_lst=float_data,
assert responses_float.data[1].embedding == responses_default.data[ embeddings_1_lst=default_data,
1].embedding name_0="float",
name_1="default",
)
@pytest.mark.asyncio @pytest.mark.asyncio
......
# SPDX-License-Identifier: Apache-2.0
"""
Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
"""
from typing import NamedTuple
import openai
import pytest
from vllm.entrypoints.openai.protocol import EmbeddingResponse
from ...utils import RemoteOpenAIServer
class ModelInfo(NamedTuple):
    """Describe one embedding model exercised by the tests in this file."""

    # Model identifier; used both to launch the server and in each request.
    name: str
    # When true, the test expects an explicit positive ``dimensions`` value
    # to be accepted; when false, it expects the request to be rejected.
    is_matryoshka: bool
# Embedding models under test, each tagged with whether the test treats it as
# a Matryoshka model (which decides if an explicit ``dimensions`` is accepted).
MODELS = [
    ModelInfo("BAAI/bge-m3", is_matryoshka=False),
    ModelInfo("jinaai/jina-embeddings-v3", is_matryoshka=True),
]

# Three copies of the same sentence, sent together as a single batch.
input_texts = ["The chef prepared a delicious meal."] * 3
@pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS)
async def test_validating_dimensions(model: ModelInfo):
    """Check that the ``dimensions`` request parameter is validated per model.

    A dedicated server is launched for *model*. For a Matryoshka model the
    values ``None`` and ``16`` must be accepted and ``-1`` rejected; for any
    other model only ``None`` must be accepted and both ``-1`` and ``16``
    rejected. A rejection is an ``openai.BadRequestError`` raised by the
    client.
    """
    args = [
        "--task",
        "embed",
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--trust_remote_code"
    ]

    with RemoteOpenAIServer(model.name, args) as remote_server:
        # Open the client as a context manager so it is closed before the
        # server is shut down.
        async with remote_server.get_async_client() as client:

            async def make_request(dimensions):
                """Send one embedding request and sanity-check the response."""
                embedding_response = await client.embeddings.create(
                    model=model.name,
                    input=input_texts,
                    dimensions=dimensions,
                    encoding_format="float",
                )
                embeddings = EmbeddingResponse.model_validate(
                    embedding_response.model_dump(mode="json"))

                assert embeddings.id is not None
                assert len(embeddings.data) == 3
                assert len(embeddings.data[0].embedding) > 0
                assert embeddings.usage.completion_tokens == 0
                assert embeddings.usage.prompt_tokens > 0
                assert embeddings.usage.total_tokens > 0

                if dimensions is not None:
                    assert len(embeddings.data[0].embedding) == dimensions

            if model.is_matryoshka:
                accepted, rejected = [None, 16], [-1]
            else:
                accepted, rejected = [None], [-1, 16]

            for dimensions in accepted:
                await make_request(dimensions)

            # One ``raises`` context per value: the first exception ends a
            # ``pytest.raises`` block, so looping inside a single context
            # would leave every value after the first one unchecked.
            for dimensions in rejected:
                with pytest.raises(openai.BadRequestError):
                    await make_request(dimensions)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment