Commit 711aa9d5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.0' into v0.10.0-dev

parents 751c492c 6d8d0a24
......@@ -4,6 +4,7 @@
import multiprocessing
import os
import numpy as np
import pytest
import torch
import torch.distributed
......@@ -177,6 +178,38 @@ def test_pynccl_all_gather():
distributed_run(all_gather_worker_fn, 2)
@worker_fn_wrapper
def all_gatherv_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)
rank = pynccl_comm.rank
world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}'
assert world_size <= 8
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
num_elems = sizes[rank]
tensor = torch.arange(num_elems, dtype=torch.float32,
device=device) + rank * 100
result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
expected = torch.cat([
torch.arange(sizes[r], dtype=torch.float32) + r * 100
for r in range(world_size)
]).to(device)
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gatherv():
distributed_run(all_gatherv_worker_fn, 2)
@worker_fn_wrapper
def reduce_scatter_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
......@@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter():
distributed_run(reduce_scatter_worker_fn, 2)
@worker_fn_wrapper
def reduce_scatterv_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)
rank = pynccl_comm.rank
world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}'
assert world_size <= 8
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
num_elems = sum(sizes)
tensor = torch.arange(num_elems, dtype=torch.float32,
device=device) + rank * 100
result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
# Calculate expected result for this rank's chunk
all_tensors = [
torch.arange(num_elems, dtype=torch.float32) + r * 100
for r in range(world_size)
]
sizes_cumsum = np.cumsum(sizes)
start = 0 if rank == 0 else sizes_cumsum[rank - 1]
end = sizes_cumsum[rank]
expected = sum(tensor[start:end] for tensor in all_tensors).to(device)
pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatterv():
distributed_run(reduce_scatterv_worker_fn, 2)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from argparse import ArgumentError, ArgumentTypeError
from argparse import ArgumentError
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Annotated, Literal, Optional
......@@ -12,8 +12,8 @@ import pytest
from vllm.config import CompilationConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, get_type_hints, is_not_builtin,
is_type, literal_to_kwargs, nullable_kvs,
optional_type, parse_type)
is_type, literal_to_kwargs, optional_type,
parse_type)
from vllm.utils import FlexibleArgumentParser
......@@ -25,18 +25,10 @@ from vllm.utils import FlexibleArgumentParser
"foo": 1,
"bar": 2
}),
(json.loads, "foo=1,bar=2", {
"foo": 1,
"bar": 2
}),
])
def test_parse_type(type, value, expected):
parse_type_func = parse_type(type)
context = nullcontext()
if value == "foo=1,bar=2":
context = pytest.warns(DeprecationWarning)
with context:
assert parse_type_func(value) == expected
assert parse_type_func(value) == expected
def test_optional_type():
......@@ -203,34 +195,6 @@ def test_get_kwargs():
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
@pytest.mark.parametrize(("arg", "expected"), [
(None, dict()),
("image=16", {
"image": 16
}),
("image=16,video=2", {
"image": 16,
"video": 2
}),
("Image=16, Video=2", {
"image": 16,
"video": 2
}),
])
def test_limit_mm_per_prompt_parser(arg, expected):
"""This functionality is deprecated and will be removed in the future.
This argument should be passed as JSON string instead.
TODO: Remove with nullable_kvs."""
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
if arg is None:
args = parser.parse_args([])
else:
args = parser.parse_args(["--limit-mm-per-prompt", arg])
assert args.limit_mm_per_prompt == expected
@pytest.mark.parametrize(
("arg", "expected"),
[
......@@ -326,18 +290,6 @@ def test_prefix_cache_default():
assert not engine_args.enable_prefix_caching
@pytest.mark.parametrize(
("arg"),
[
"image", # Missing =
"image=4,image=5", # Conflicting values
"image=video=4" # Too many = in tokenized arg
])
def test_bad_nullable_kvs(arg):
with pytest.raises(ArgumentTypeError):
nullable_kvs(arg)
# yapf: disable
@pytest.mark.parametrize(("arg", "expected", "option"), [
(None, None, "mm-processor-kwargs"),
......
......@@ -17,16 +17,19 @@ from vllm.platforms import current_platform
from ...utils import models_path_prefix
MODEL_NAMES = [
os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct"),
os.path.join(models_path_prefix, "Qwen/Qwen3-1.7B"),
os.path.join(models_path_prefix, "google/gemma-3-1b-it"),
]
FP8_KV_MODEL_NAMES = [
os.path.join(models_path_prefix, "Qwen/Qwen3-1.7B"),
]
NUM_CONCURRENT = 500
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUES = {
"Qwen/Qwen2-1.5B-Instruct": 0.58,
"google/gemma-3-1b-it": 0.25,
os.path.join(models_path_prefix, "Qwen/Qwen3-1.7B"): 0.68,
os.path.join(models_path_prefix, "google/gemma-3-1b-it"): 0.25,
}
......@@ -71,6 +74,10 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
more_args = None
if current_platform.is_tpu():
# Limit compilation time for TPU V1
# xet doesn't work well for both Qwen/Qwen3-1.7B and
# google/gemma-3-1b-it
m.setenv("HF_HUB_DISABLE_XET", "1")
more_args = "max_model_len=2048,max_num_seqs=64"
# Add TP test (if provided)
......@@ -80,9 +87,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
run_test(model, more_args)
def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
"""Run with the V0 Engine."""
@pytest.mark.skipif(not current_platform.is_cuda()
and not current_platform.is_tpu(),
reason="V1 is currently only supported on CUDA and TPU")
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
model, monkeypatch: pytest.MonkeyPatch):
"""Run with the V1 Engine."""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
run_test(os.path.join(models_path_prefix,"Qwen/Qwen2-1.5B-Instruct"))
\ No newline at end of file
m.setenv("VLLM_USE_V1", "1")
more_args = None
if current_platform.is_tpu():
# Limit compilation time for TPU V1
# xet doesn't work well for Qwen/Qwen3-1.7B
m.setenv("HF_HUB_DISABLE_XET", "1")
more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
# Add TP test (if provided)
if TPU_TP_TEST_STR:
more_args += ",{}".format(TPU_TP_TEST_STR)
run_test(model, more_args)
......@@ -18,15 +18,19 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...utils import models_path_prefix
MODEL_NAME = os.path.join(models_path_prefix, "Qwen2.5-1.5B-Instruct")
GUIDED_DECODING_BACKENDS = [
MODEL_NAME = os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct")
# Separate backends which support grammars vs ones
# which only support regex based constraints in tests.
GRAMMAR_DECODING_BACKENDS = [
# (backend, disable_any_whitespace),
("outlines", False),
("lm-format-enforcer", False),
("xgrammar", True),
("guidance", True),
]
ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
@pytest.fixture(scope="module")
def llm():
......@@ -42,7 +46,7 @@ def llm():
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
......@@ -52,6 +56,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
regex=sample_regex,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
......@@ -72,7 +77,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -106,7 +111,7 @@ def test_guided_json_completion(sample_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -141,7 +146,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -176,7 +181,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -221,7 +226,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -251,7 +256,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
GRAMMAR_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
......@@ -347,7 +352,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
GRAMMAR_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
......@@ -380,7 +385,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
# A list is not what was intended, but is still valid
# json.
assert isinstance(parsed_json, (dict, list))
class CarType(str, Enum):
......@@ -398,7 +405,7 @@ class CarDescription(BaseModel):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
json_schema = CarDescription.model_json_schema()
......@@ -430,7 +437,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sample_output_schema = {
......
......@@ -70,8 +70,9 @@ def run_test(more_args):
@pytest.mark.skipif(not current_platform.is_cuda()
and not current_platform.is_tpu(),
reason="V1 currently only supported on CUDA and TPU")
and not current_platform.is_tpu()
and not current_platform.is_xpu(),
reason="V1 currently only supported on CUDA, XPU and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
"""Run with the V1 Engine."""
......
......@@ -1118,10 +1118,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME, ""])
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
model_name: str):
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
openai_api_key = "EMPTY"
openai_api_base = f"http://localhost:{server.port}/v1"
......@@ -1140,3 +1137,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
messages=messages,
)
assert response.model == MODEL_NAME
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]
request_args = {
"model": MODEL_NAME,
"messages": messages,
"max_completion_tokens": 5,
"temperature": 0.0,
"logprobs": False,
}
chat_completion = await client.chat.completions.create(**request_args)
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
chat_output = chat_completion.model_dump()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
assert chat_output["choices"] == invocation_output["choices"]
......@@ -155,3 +155,29 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer,
assert output.object == "list"
assert isinstance(output.data, list)
assert len(output.data) == 0
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer):
request_args = {
"model": MODEL_NAME,
"input": "This product was excellent and exceeded my expectations"
}
classification_response = requests.post(server.url_for("classify"),
json=request_args)
classification_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
classification_output = classification_response.json()
invocation_output = invocation_response.json()
assert classification_output.keys() == invocation_output.keys()
for classification_data, invocation_data in zip(
classification_output["data"], invocation_output["data"]):
assert classification_data.keys() == invocation_data.keys()
assert classification_data["probs"] == pytest.approx(
invocation_data["probs"], rel=0.01)
......@@ -153,3 +153,13 @@ def test_chat_template_validation_for_sad_paths(serve_parser):
args = serve_parser.parse_args(args=["--chat-template", "does/not/exist"])
with pytest.raises(ValueError):
validate_parsed_serve_args(args)
@pytest.mark.parametrize(
"cli_args, expected_middleware",
[(["--middleware", "middleware1", "--middleware", "middleware2"
], ["middleware1", "middleware2"]), ([], [])])
def test_middleware(serve_parser, cli_args, expected_middleware):
"""Ensure multiple middleware args are parsed properly"""
args = serve_parser.parse_args(args=cli_args)
assert args.middleware == expected_middleware
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# imports for guided decoding tests
import json
import os
import shutil
from tempfile import TemporaryDirectory
from typing import Optional
......@@ -12,6 +13,7 @@ import pytest
import os
import pytest_asyncio
import regex as re
import requests
# downloading lora to test lora requests
# from huggingface_hub import snapshot_download
from openai import BadRequestError
......@@ -26,10 +28,6 @@ MODEL_NAME = os.path.join(models_path_prefix, "HuggingFaceH4/zephyr-7b-beta")
# technically these adapters use a different base model,
# but we're not testing generation quality here
LORA_NAME = os.path.join(models_path_prefix, "typeof/zephyr-7b-beta-lora")
PA_NAME = os.path.join(models_path_prefix, "swapnilbp/llama_tweet_ptune")
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
......@@ -57,14 +55,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files):
@pytest.fixture(scope="module")
def zephyr_pa_files():
# return snapshot_download(repo_id=PA_NAME)
return PA_NAME
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
......@@ -83,15 +74,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
"64",
"--max-cpu-loras",
"2",
# pa config
"--enable-prompt-adapter",
"--prompt-adapters",
f"zephyr-pa={zephyr_pa_files}",
f"zephyr-pa2={zephyr_pa_files}",
"--max-prompt-adapters",
"2",
"--max-prompt-adapter-token",
"128",
]
......@@ -100,8 +82,19 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
def server(default_server_args, request):
if request.param:
default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
original_value = os.environ.get('VLLM_USE_V1')
os.environ['VLLM_USE_V1'] = '0'
try:
with RemoteOpenAIServer(MODEL_NAME,
default_server_args) as remote_server:
yield remote_server
finally:
# Restore original env value
if original_value is None:
os.environ.pop('VLLM_USE_V1', None)
else:
os.environ['VLLM_USE_V1'] = original_value
@pytest_asyncio.fixture
......@@ -112,14 +105,11 @@ async def client(server):
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
"model_name,num_virtual_tokens",
[(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
num_virtual_tokens: int):
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
......@@ -132,9 +122,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5,
prompt_tokens=6 + num_virtual_tokens,
total_tokens=11 + num_virtual_tokens)
completion_tokens=5, prompt_tokens=6, total_tokens=11)
# test using token IDs
completion = await client.completions.create(
......@@ -177,9 +165,9 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
......@@ -196,9 +184,9 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora and 1 pa hereafter
# just test 1 lora
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
......@@ -219,7 +207,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
......@@ -240,7 +228,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
model_name: str):
......@@ -316,7 +304,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str):
......@@ -350,7 +338,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling.
......@@ -384,7 +372,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str):
......@@ -521,7 +509,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test both text and token IDs
......@@ -836,3 +824,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI,
assert content is not None and saying in content
else:
assert content is not None and saying not in content
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
client: openai.AsyncOpenAI):
request_args = {
"model": MODEL_NAME,
"prompt": "Hello, my name is",
"max_tokens": 5,
"temperature": 0.0,
"logprobs": None,
}
completion = await client.completions.create(**request_args)
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
completion_output = completion.model_dump()
invocation_output = invocation_response.json()
assert completion_output.keys() == invocation_output.keys()
assert completion_output["choices"] == invocation_output["choices"]
......@@ -72,8 +72,43 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
"options": {
"$ref": "#/$defs/WeatherOptions",
"description":
"Optional parameters for weather query",
},
},
"required": ["country", "unit"],
"$defs": {
"WeatherOptions": {
"title": "WeatherOptions",
"type": "object",
"additionalProperties": False,
"properties": {
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"default": "celsius",
"description": "Temperature unit",
"title": "Temperature Unit",
},
"include_forecast": {
"type": "boolean",
"default": False,
"description":
"Whether to include a 24-hour forecast",
"title": "Include Forecast",
},
"language": {
"type": "string",
"default": "zh-CN",
"description": "Language of the response",
"title": "Language",
"enum": ["zh-CN", "en-US", "ja-JP"],
},
},
},
},
},
},
},
......@@ -145,7 +180,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
"enable_thinking": enable_thinking
}
})
if enable_thinking:
assert chat_completion.choices[0].message.\
reasoning_content is not None
assert chat_completion.choices[0].message.\
reasoning_content != ""
assert chat_completion.choices[0].message.tool_calls is not None
assert len(chat_completion.choices[0].message.tool_calls) > 0
else:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download
from ...conftest import AudioTestAssets
from ...utils import RemoteOpenAIServer
# NOTE - the tests in this module are currently analogous to test_chat, but are
# separated to avoid OOM killing due to module-scoped servers, since we
# need a multimodal model for these tests.
# Contains a modality specific lora alongside the base model
MULTIMODAL_MODEL_NAME = snapshot_download(
"microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")
ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go." # noqa: E501
@pytest.fixture(scope="module")
def monkeypatch_module():
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="module", params=[False, True])
def multimodal_server(request, monkeypatch_module): # noqa: F811
use_v1 = request.param
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"half",
"--max-model-len",
"12800",
"--enforce-eager",
# lora config below
"--enable-lora",
"--lora-modules",
f"speech={AUDIO_LORA_PATH}",
"--max-lora-rank",
"320",
"--max-num-seqs",
"2",
"--trust-remote-code",
"--gpu-memory-utilization",
"0.8",
"--default-mm-loras",
f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}",
]
with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def multi_modal_client(multimodal_server):
async with multimodal_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
# base model with default lora should give the same response as lora model
"model_name",
[MULTIMODAL_MODEL_NAME, "speech"],
)
async def test_default_mm_lora_chat_completions(
model_name: str,
multi_modal_client: openai.AsyncOpenAI,
audio_assets: AudioTestAssets,
):
messages = [{
"role":
"user",
"content": [{
"type": "text",
"text": "Can you transcribe this audio?",
}, {
"type": "audio_url",
"audio_url": {
"url": audio_assets[0].url
},
}]
}]
chat_completion = await multi_modal_client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=128,
temperature=0.0)
assert len(chat_completion.choices) > 0
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
assert message.content == ACTIVE_MM_LORA_RESPONSE
......@@ -15,6 +15,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
from ...models.language.pooling.embed_utils import (
run_embedding_correctness_test)
from ...models.utils import check_embeddings_close
from ...utils import RemoteOpenAIServer, models_path_prefix
......@@ -298,3 +299,75 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
assert "error" in response.object
assert "truncate_prompt_tokens value is greater than max_model_len. "\
"Please, select a smaller truncation size." in response.message
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
client: openai.AsyncOpenAI):
input_texts = [
"The chef prepared a delicious meal.",
]
request_args = {
"model": MODEL_NAME,
"input": input_texts,
"encoding_format": "float",
}
completion_response = await client.embeddings.create(**request_args)
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
completion_output = completion_response.model_dump()
invocation_output = invocation_response.json()
assert completion_output.keys() == invocation_output.keys()
for completion_data, invocation_data in zip(completion_output["data"],
invocation_output["data"]):
assert completion_data.keys() == invocation_data.keys()
check_embeddings_close(embeddings_0_lst=[completion_data["embedding"]],
embeddings_1_lst=[invocation_data["embedding"]],
name_0="completion",
name_1="invocation")
@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
messages = [{
"role": "user",
"content": "The cat sat on the mat.",
}, {
"role": "assistant",
"content": "A feline was resting on a rug.",
}, {
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
}]
request_args = {
"model": MODEL_NAME,
"messages": messages,
"encoding_format": "float",
}
chat_response = requests.post(server.url_for("v1/embeddings"),
json=request_args)
chat_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
chat_output = chat_response.json()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
for chat_data, invocation_data in zip(chat_output["data"],
invocation_output["data"]):
assert chat_data.keys() == invocation_data.keys()
check_embeddings_close(embeddings_0_lst=[chat_data["embedding"]],
embeddings_1_lst=[invocation_data["embedding"]],
name_0="chat",
name_1="invocation")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from typing import Final
import pytest
......@@ -29,7 +30,7 @@ def server():
"--enforce-eager",
"--trust-remote-code",
"--limit-mm-per-prompt",
f"image={MAXIMUM_IMAGES}",
json.dumps({"image": MAXIMUM_IMAGES}),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
......@@ -95,6 +96,10 @@ def test_openapi_stateless(case: schemathesis.Case):
case.operation.method.upper(),
case.operation.path,
)
if case.operation.path.startswith("/v1/responses"):
# Skip responses API as it is meant to be stateful.
return
timeout = {
# requires a longer timeout
("POST", "/v1/chat/completions"):
......
......@@ -13,7 +13,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
MODEL_NAME = "internlm/internlm2-1_8b-reward"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
......@@ -21,15 +21,16 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
def server():
args = [
"--task",
"classify",
"reward",
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--enforce-eager",
"--max-model-len",
"8192",
"512",
"--chat-template",
DUMMY_CHAT_TEMPLATE,
"--trust-remote-code",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
......@@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 2
assert len(poolings.data[0].data) == 8
assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 7
assert poolings.usage.total_tokens == 7
assert poolings.usage.prompt_tokens == 8
assert poolings.usage.total_tokens == 8
# test using token IDs
input_tokens = [1, 1, 1, 1, 1]
......@@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 2
assert len(poolings.data[0].data) == 5
assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 5
assert poolings.usage.total_tokens == 5
......@@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None
assert len(poolings.data) == 3
assert len(poolings.data[0].data) == 2
assert len(poolings.data[0].data) == 8
assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 25
assert poolings.usage.total_tokens == 25
assert poolings.usage.prompt_tokens == 29
assert poolings.usage.total_tokens == 29
# test list[list[int]]
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
......@@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None
assert len(poolings.data) == 4
assert len(poolings.data[0].data) == 2
assert len(poolings.data[0].data) == 5
assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 17
assert poolings.usage.total_tokens == 17
......@@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer,
chat_response.raise_for_status()
chat_poolings = PoolingResponse.model_validate(chat_response.json())
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(
tokenizer_name=model_name,
tokenizer_mode="fast",
trust_remote_code=True,
)
prompt = tokenizer.apply_chat_template(
messages,
chat_template=DUMMY_CHAT_TEMPLATE,
......@@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
)
float_response.raise_for_status()
responses_float = PoolingResponse.model_validate(float_response.json())
float_data = [
np.array(d.data).squeeze(-1).tolist() for d in responses_float.data
]
base64_response = requests.post(
server.url_for("pooling"),
......@@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
np.frombuffer(base64.b64decode(data.data),
dtype="float32").tolist())
check_embeddings_close(
embeddings_0_lst=[d.data for d in responses_float.data],
embeddings_1_lst=decoded_responses_base64_data,
name_0="float32",
name_1="base64")
check_embeddings_close(embeddings_0_lst=float_data,
embeddings_1_lst=decoded_responses_base64_data,
name_0="float32",
name_1="base64")
# Default response is float32 decoded from base64 by OpenAI Client
default_response = requests.post(
......@@ -240,9 +247,83 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
)
default_response.raise_for_status()
responses_default = PoolingResponse.model_validate(default_response.json())
default_data = [
np.array(d.data).squeeze(-1).tolist() for d in responses_default.data
]
check_embeddings_close(embeddings_0_lst=float_data,
embeddings_1_lst=default_data,
name_0="float32",
name_1="default")
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer):
input_texts = [
"The chef prepared a delicious meal.",
]
request_args = {
"model": MODEL_NAME,
"input": input_texts,
"encoding_format": "float",
}
completion_response = requests.post(server.url_for("pooling"),
json=request_args)
completion_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
completion_output = completion_response.json()
invocation_output = invocation_response.json()
assert completion_output.keys() == invocation_output.keys()
for completion_data, invocation_data in zip(completion_output["data"],
invocation_output["data"]):
assert completion_data.keys() == invocation_data.keys()
check_embeddings_close(embeddings_0_lst=completion_data["data"],
embeddings_1_lst=invocation_data["data"],
name_0="completion",
name_1="invocation")
@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
messages = [{
"role": "user",
"content": "The cat sat on the mat.",
}, {
"role": "assistant",
"content": "A feline was resting on a rug.",
}, {
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
}]
request_args = {
"model": MODEL_NAME,
"messages": messages,
"encoding_format": "float",
}
chat_response = requests.post(server.url_for("pooling"), json=request_args)
chat_response.raise_for_status()
check_embeddings_close(
embeddings_0_lst=[d.data for d in responses_default.data],
embeddings_1_lst=[d.data for d in responses_default.data],
name_0="float32",
name_1="base64")
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
chat_output = chat_response.json()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
for chat_data, invocation_data in zip(chat_output["data"],
invocation_output["data"]):
assert chat_data.keys() == invocation_data.keys()
check_embeddings_close(embeddings_0_lst=chat_data["data"],
embeddings_1_lst=invocation_data["data"],
name_0="chat",
name_1="invocation")
......@@ -94,3 +94,34 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
# Assert just a small fragments of the response
assert "Please reduce the length of the input." in \
rerank_response.text
def test_invocations(server: RemoteOpenAIServer):
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
request_args = {
"model": MODEL_NAME,
"query": query,
"documents": documents,
}
rerank_response = requests.post(server.url_for("rerank"),
json=request_args)
rerank_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
rerank_output = rerank_response.json()
invocation_output = invocation_response.json()
assert rerank_output.keys() == invocation_output.keys()
for rerank_result, invocations_result in zip(rerank_output["results"],
invocation_output["results"]):
assert rerank_result.keys() == invocations_result.keys()
assert rerank_result["relevance_score"] == pytest.approx(
invocations_result["relevance_score"], rel=0.01)
......@@ -13,7 +13,6 @@ from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
from .test_completion import zephyr_pa_files # noqa: F401
from .test_completion import MODEL_NAME
......
......@@ -193,3 +193,32 @@ class TestModel:
assert score_response.status_code == 400
assert "Please, select a smaller truncation size." in \
score_response.text
def test_invocations(self, server: RemoteOpenAIServer, model: dict[str,
Any]):
text_1 = "What is the capital of France?"
text_2 = "The capital of France is Paris."
request_args = {
"model": model["name"],
"text_1": text_1,
"text_2": text_2,
}
score_response = requests.post(server.url_for("score"),
json=request_args)
score_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
score_output = score_response.json()
invocation_output = invocation_response.json()
assert score_output.keys() == invocation_output.keys()
for score_data, invocation_data in zip(score_output["data"],
invocation_output["data"]):
assert score_data.keys() == invocation_data.keys()
assert score_data["score"] == pytest.approx(
invocation_data["score"], rel=0.01)
......@@ -8,6 +8,8 @@ from dataclasses import dataclass, field
from typing import Any, Optional
from unittest.mock import MagicMock
import pytest
from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
......@@ -75,7 +77,8 @@ def test_async_serving_chat_init():
assert serving_completion.chat_template == CHAT_TEMPLATE
def test_serving_chat_should_set_correct_max_tokens():
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
......@@ -90,6 +93,7 @@ def test_serving_chat_should_set_correct_max_tokens():
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
......@@ -100,13 +104,13 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93
req.max_tokens = 10
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
......@@ -145,7 +149,7 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
......@@ -153,7 +157,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 15
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
......@@ -161,7 +165,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 5
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 5
......@@ -200,7 +204,7 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93
......@@ -208,7 +212,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 100
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93
......@@ -216,12 +220,13 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 5
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 5
def test_serving_chat_could_load_correct_generation_config():
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
......@@ -244,6 +249,7 @@ def test_serving_chat_could_load_correct_generation_config():
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
......@@ -254,7 +260,7 @@ def test_serving_chat_could_load_correct_generation_config():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.5
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
......@@ -263,7 +269,7 @@ def test_serving_chat_could_load_correct_generation_config():
req.temperature = 0.1
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.1
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
......@@ -272,13 +278,14 @@ def test_serving_chat_could_load_correct_generation_config():
req.temperature = 0.0
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.0
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
def test_serving_chat_did_set_correct_cache_salt():
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt():
mock_model_config = MockModelConfig()
mock_engine = MagicMock(spec=MQLLMEngineClient)
......@@ -308,11 +315,11 @@ def test_serving_chat_did_set_correct_cache_salt():
# By default cache_salt in the engine prompt is not set
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert "cache_salt" not in mock_engine.generate.call_args.args[0]
# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
......@@ -34,8 +34,7 @@ async def _async_serving_models_init() -> OpenAIServingModels:
serving_models = OpenAIServingModels(engine_client=mock_engine_client,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config,
lora_modules=None,
prompt_adapters=None)
lora_modules=None)
await serving_models.init_static_loras()
return serving_models
......@@ -59,7 +58,8 @@ async def test_load_lora_adapter_success():
response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
assert len(serving_models.lora_requests) == 1
assert serving_models.lora_requests[0].lora_name == "adapter"
assert "adapter" in serving_models.lora_requests
assert serving_models.lora_requests["adapter"].lora_name == "adapter"
@pytest.mark.asyncio
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import json
import os
import tempfile
import openai
......@@ -58,18 +58,20 @@ def tensorize_model_and_lora(tmp_dir, model_uri):
@pytest.fixture(scope="module")
def server(model_uri, tensorize_model_and_lora):
model_loader_extra_config = {
"tensorizer_uri": model_uri,
}
# In this case, model_uri is a directory with a model.tensors
# file and all necessary model artifacts, particularly a
# HF `config.json` file. In this case, Tensorizer can infer the
# `TensorizerConfig` so --model-loader-extra-config can be completely
# omitted.
## Start OpenAI API server
args = [
"--load-format", "tensorizer", "--device", "cuda",
"--model-loader-extra-config",
json.dumps(model_loader_extra_config), "--enable-lora"
"--load-format", "tensorizer", "--served-model-name", MODEL_NAME,
"--enable-lora"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
model_dir = os.path.dirname(model_uri)
with RemoteOpenAIServer(model_dir, args) as remote_server:
yield remote_server
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment