Commit 711aa9d5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.0' into v0.10.0-dev

parents 751c492c 6d8d0a24
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import multiprocessing import multiprocessing
import os import os
import numpy as np
import pytest import pytest
import torch import torch
import torch.distributed import torch.distributed
...@@ -177,6 +178,38 @@ def test_pynccl_all_gather(): ...@@ -177,6 +178,38 @@ def test_pynccl_all_gather():
distributed_run(all_gather_worker_fn, 2) distributed_run(all_gather_worker_fn, 2)
@worker_fn_wrapper
def all_gatherv_worker_fn():
    """Worker body exercising ``PyNcclCommunicator.all_gatherv``.

    Every rank contributes a differently sized float32 vector (an
    ``arange`` shifted by ``rank * 100``). After the collective, the
    output buffer must equal all per-rank vectors concatenated in rank
    order.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    my_rank = pynccl_comm.rank
    n_ranks = pynccl_comm.world_size
    cuda_dev = f'cuda:{pynccl_comm.rank}'

    # Only eight per-rank sizes are listed below.
    assert n_ranks <= 8
    per_rank_sizes = [81, 20, 57, 52, 81, 5, 49, 49][:n_ranks]

    def rank_payload(r, **arange_kwargs):
        # Values contributed by rank ``r``: 0..size-1, offset by r * 100.
        return torch.arange(per_rank_sizes[r],
                            dtype=torch.float32,
                            **arange_kwargs) + r * 100

    send_buf = rank_payload(my_rank, device=cuda_dev)
    recv_buf = torch.zeros(sum(per_rank_sizes),
                           dtype=torch.float32,
                           device=cuda_dev)

    # Reference result is built without a device argument, then moved over.
    expected = torch.cat([rank_payload(r)
                          for r in range(n_ranks)]).to(cuda_dev)

    pynccl_comm.all_gatherv(recv_buf, send_buf, sizes=per_rank_sizes)
    torch.cuda.synchronize()
    torch.testing.assert_close(recv_buf, expected, rtol=1e-5, atol=1e-8)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gatherv():
    """Run the variable-size all-gather worker through ``distributed_run``."""
    num_workers = 2  # matches the two-GPU requirement in the skip condition
    distributed_run(all_gatherv_worker_fn, num_workers)
@worker_fn_wrapper @worker_fn_wrapper
def reduce_scatter_worker_fn(): def reduce_scatter_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
...@@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter(): ...@@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter():
distributed_run(reduce_scatter_worker_fn, 2) distributed_run(reduce_scatter_worker_fn, 2)
@worker_fn_wrapper
def reduce_scatterv_worker_fn():
    """Worker body exercising ``PyNcclCommunicator.reduce_scatterv``.

    Every rank supplies a full-length float32 vector (an ``arange``
    shifted by ``rank * 100``). Each rank's output must equal the
    element-wise sum over all ranks' vectors, restricted to that rank's
    chunk as given by the uneven ``sizes`` list.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    my_rank = pynccl_comm.rank
    n_ranks = pynccl_comm.world_size
    cuda_dev = f'cuda:{pynccl_comm.rank}'

    # Only eight per-rank chunk sizes are listed below.
    assert n_ranks <= 8
    chunk_sizes = [81, 20, 57, 52, 81, 5, 49, 49][:n_ranks]
    total_elems = sum(chunk_sizes)

    send_buf = torch.arange(total_elems,
                            dtype=torch.float32,
                            device=cuda_dev) + my_rank * 100
    recv_buf = torch.zeros(chunk_sizes[my_rank],
                           dtype=torch.float32,
                           device=cuda_dev)

    # Bounds of this rank's chunk inside the full-length vector.
    chunk_ends = np.cumsum(chunk_sizes)
    lo = 0 if my_rank == 0 else chunk_ends[my_rank - 1]
    hi = chunk_ends[my_rank]

    # Reference: sum every rank's input over [lo, hi), then move to the GPU.
    per_rank_inputs = [
        torch.arange(total_elems, dtype=torch.float32) + r * 100
        for r in range(n_ranks)
    ]
    expected = sum(full[lo:hi] for full in per_rank_inputs).to(cuda_dev)

    pynccl_comm.reduce_scatterv(recv_buf, send_buf, sizes=chunk_sizes)
    torch.cuda.synchronize()
    torch.testing.assert_close(recv_buf, expected, rtol=1e-5, atol=1e-8)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatterv():
    """Run the variable-size reduce-scatter worker through ``distributed_run``."""
    num_workers = 2  # matches the two-GPU requirement in the skip condition
    distributed_run(reduce_scatterv_worker_fn, num_workers)
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.") reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph(): def test_pynccl_with_cudagraph():
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json import json
from argparse import ArgumentError, ArgumentTypeError from argparse import ArgumentError
from contextlib import nullcontext from contextlib import nullcontext
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Annotated, Literal, Optional from typing import Annotated, Literal, Optional
...@@ -12,8 +12,8 @@ import pytest ...@@ -12,8 +12,8 @@ import pytest
from vllm.config import CompilationConfig, config from vllm.config import CompilationConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, get_type_hints, is_not_builtin, get_type, get_type_hints, is_not_builtin,
is_type, literal_to_kwargs, nullable_kvs, is_type, literal_to_kwargs, optional_type,
optional_type, parse_type) parse_type)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -25,18 +25,10 @@ from vllm.utils import FlexibleArgumentParser ...@@ -25,18 +25,10 @@ from vllm.utils import FlexibleArgumentParser
"foo": 1, "foo": 1,
"bar": 2 "bar": 2
}), }),
(json.loads, "foo=1,bar=2", {
"foo": 1,
"bar": 2
}),
]) ])
def test_parse_type(type, value, expected): def test_parse_type(type, value, expected):
parse_type_func = parse_type(type) parse_type_func = parse_type(type)
context = nullcontext() assert parse_type_func(value) == expected
if value == "foo=1,bar=2":
context = pytest.warns(DeprecationWarning)
with context:
assert parse_type_func(value) == expected
def test_optional_type(): def test_optional_type():
...@@ -203,34 +195,6 @@ def test_get_kwargs(): ...@@ -203,34 +195,6 @@ def test_get_kwargs():
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4 assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
@pytest.mark.parametrize(("arg", "expected"), [
(None, dict()),
("image=16", {
"image": 16
}),
("image=16,video=2", {
"image": 16,
"video": 2
}),
("Image=16, Video=2", {
"image": 16,
"video": 2
}),
])
def test_limit_mm_per_prompt_parser(arg, expected):
"""This functionality is deprecated and will be removed in the future.
This argument should be passed as JSON string instead.
TODO: Remove with nullable_kvs."""
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
if arg is None:
args = parser.parse_args([])
else:
args = parser.parse_args(["--limit-mm-per-prompt", arg])
assert args.limit_mm_per_prompt == expected
@pytest.mark.parametrize( @pytest.mark.parametrize(
("arg", "expected"), ("arg", "expected"),
[ [
...@@ -326,18 +290,6 @@ def test_prefix_cache_default(): ...@@ -326,18 +290,6 @@ def test_prefix_cache_default():
assert not engine_args.enable_prefix_caching assert not engine_args.enable_prefix_caching
@pytest.mark.parametrize(
("arg"),
[
"image", # Missing =
"image=4,image=5", # Conflicting values
"image=video=4" # Too many = in tokenized arg
])
def test_bad_nullable_kvs(arg):
with pytest.raises(ArgumentTypeError):
nullable_kvs(arg)
# yapf: disable # yapf: disable
@pytest.mark.parametrize(("arg", "expected", "option"), [ @pytest.mark.parametrize(("arg", "expected", "option"), [
(None, None, "mm-processor-kwargs"), (None, None, "mm-processor-kwargs"),
......
...@@ -17,16 +17,19 @@ from vllm.platforms import current_platform ...@@ -17,16 +17,19 @@ from vllm.platforms import current_platform
from ...utils import models_path_prefix from ...utils import models_path_prefix
MODEL_NAMES = [ MODEL_NAMES = [
os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct"), os.path.join(models_path_prefix, "Qwen/Qwen3-1.7B"),
os.path.join(models_path_prefix, "google/gemma-3-1b-it"), os.path.join(models_path_prefix, "google/gemma-3-1b-it"),
] ]
FP8_KV_MODEL_NAMES = [
os.path.join(models_path_prefix, "Qwen/Qwen3-1.7B"),
]
NUM_CONCURRENT = 500 NUM_CONCURRENT = 500
TASK = "gsm8k" TASK = "gsm8k"
FILTER = "exact_match,strict-match" FILTER = "exact_match,strict-match"
RTOL = 0.03 RTOL = 0.03
EXPECTED_VALUES = { EXPECTED_VALUES = {
"Qwen/Qwen2-1.5B-Instruct": 0.58, os.path.join(models_path_prefix, "Qwen/Qwen3-1.7B"): 0.68,
"google/gemma-3-1b-it": 0.25, os.path.join(models_path_prefix, "google/gemma-3-1b-it"): 0.25,
} }
...@@ -71,6 +74,10 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch): ...@@ -71,6 +74,10 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
more_args = None more_args = None
if current_platform.is_tpu(): if current_platform.is_tpu():
# Limit compilation time for TPU V1 # Limit compilation time for TPU V1
# xet doesn't work well for both Qwen/Qwen3-1.7B and
# google/gemma-3-1b-it
m.setenv("HF_HUB_DISABLE_XET", "1")
more_args = "max_model_len=2048,max_num_seqs=64" more_args = "max_model_len=2048,max_num_seqs=64"
# Add TP test (if provided) # Add TP test (if provided)
...@@ -80,9 +87,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch): ...@@ -80,9 +87,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
run_test(model, more_args) run_test(model, more_args)
def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch): @pytest.mark.skipif(not current_platform.is_cuda()
"""Run with the V0 Engine.""" and not current_platform.is_tpu(),
reason="V1 is currently only supported on CUDA and TPU")
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
model, monkeypatch: pytest.MonkeyPatch):
"""Run with the V1 Engine."""
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0") m.setenv("VLLM_USE_V1", "1")
run_test(os.path.join(models_path_prefix,"Qwen/Qwen2-1.5B-Instruct"))
\ No newline at end of file more_args = None
if current_platform.is_tpu():
# Limit compilation time for TPU V1
# xet doesn't work well for Qwen/Qwen3-1.7B
m.setenv("HF_HUB_DISABLE_XET", "1")
more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
# Add TP test (if provided)
if TPU_TP_TEST_STR:
more_args += ",{}".format(TPU_TP_TEST_STR)
run_test(model, more_args)
...@@ -18,15 +18,19 @@ from vllm.outputs import RequestOutput ...@@ -18,15 +18,19 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...utils import models_path_prefix from ...utils import models_path_prefix
MODEL_NAME = os.path.join(models_path_prefix, "Qwen2.5-1.5B-Instruct") MODEL_NAME = os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct")
GUIDED_DECODING_BACKENDS = [
# Separate backends which support grammars vs ones
# which only support regex based constraints in tests.
GRAMMAR_DECODING_BACKENDS = [
# (backend, disable_any_whitespace), # (backend, disable_any_whitespace),
("outlines", False),
("lm-format-enforcer", False), ("lm-format-enforcer", False),
("xgrammar", True), ("xgrammar", True),
("guidance", True), ("guidance", True),
] ]
ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def llm():
...@@ -42,7 +46,7 @@ def llm(): ...@@ -42,7 +46,7 @@ def llm():
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
sampling_params = SamplingParams( sampling_params = SamplingParams(
...@@ -52,6 +56,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, ...@@ -52,6 +56,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
regex=sample_regex, regex=sample_regex,
backend=guided_decoding_backend, backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace)) disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[ outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}" f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2, ] * 2,
...@@ -72,7 +77,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, ...@@ -72,7 +77,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm, def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -106,7 +111,7 @@ def test_guided_json_completion(sample_json_schema, llm, ...@@ -106,7 +111,7 @@ def test_guided_json_completion(sample_json_schema, llm,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm, def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -141,7 +146,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm, ...@@ -141,7 +146,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm, def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -176,7 +181,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm, ...@@ -176,7 +181,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm, def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -221,7 +226,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm, ...@@ -221,7 +226,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm, def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -251,7 +256,7 @@ def test_guided_choice_completion(sample_guided_choice, llm, ...@@ -251,7 +256,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) GRAMMAR_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm, def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
...@@ -347,7 +352,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm): ...@@ -347,7 +352,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) GRAMMAR_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str, def test_guided_json_object(llm, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
sampling_params = SamplingParams( sampling_params = SamplingParams(
...@@ -380,7 +385,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str, ...@@ -380,7 +385,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,
# Parse to verify it is valid JSON # Parse to verify it is valid JSON
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) # A list is not what was intended, but is still valid
# json.
assert isinstance(parsed_json, (dict, list))
class CarType(str, Enum): class CarType(str, Enum):
...@@ -398,7 +405,7 @@ class CarDescription(BaseModel): ...@@ -398,7 +405,7 @@ class CarDescription(BaseModel):
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
...@@ -430,7 +437,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, ...@@ -430,7 +437,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS) ALL_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str, def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
disable_any_whitespace: bool): disable_any_whitespace: bool):
sample_output_schema = { sample_output_schema = {
......
...@@ -70,8 +70,9 @@ def run_test(more_args): ...@@ -70,8 +70,9 @@ def run_test(more_args):
@pytest.mark.skipif(not current_platform.is_cuda() @pytest.mark.skipif(not current_platform.is_cuda()
and not current_platform.is_tpu(), and not current_platform.is_tpu()
reason="V1 currently only supported on CUDA and TPU") and not current_platform.is_xpu(),
reason="V1 currently only supported on CUDA, XPU and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch): def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
"""Run with the V1 Engine.""" """Run with the V1 Engine."""
......
...@@ -1118,10 +1118,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer): ...@@ -1118,10 +1118,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME, ""]) async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
model_name: str):
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = f"http://localhost:{server.port}/v1" openai_api_base = f"http://localhost:{server.port}/v1"
...@@ -1140,3 +1137,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer, ...@@ -1140,3 +1137,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
messages=messages, messages=messages,
) )
assert response.model == MODEL_NAME assert response.model == MODEL_NAME
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
                           client: openai.AsyncOpenAI):
    """Compare a chat completion with the same request sent to ``invocations``.

    The identical payload is sent once through the OpenAI client and once
    as a raw POST to the server's ``invocations`` URL. Both responses must
    expose the same top-level keys and identical ``choices``.
    """
    payload = {
        "model": MODEL_NAME,
        "messages": [
            {
                "role": "system",
                "content": "you are a helpful assistant"
            },
            {
                "role": "user",
                "content": "what is 1+1?"
            },
        ],
        "max_completion_tokens": 5,
        "temperature": 0.0,
        "logprobs": False,
    }

    sdk_completion = await client.chat.completions.create(**payload)

    raw_response = requests.post(server.url_for("invocations"), json=payload)
    raw_response.raise_for_status()

    via_chat = sdk_completion.model_dump()
    via_invocations = raw_response.json()

    assert via_chat.keys() == via_invocations.keys()
    assert via_chat["choices"] == via_invocations["choices"]
...@@ -155,3 +155,29 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer, ...@@ -155,3 +155,29 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer,
assert output.object == "list" assert output.object == "list"
assert isinstance(output.data, list) assert isinstance(output.data, list)
assert len(output.data) == 0 assert len(output.data) == 0
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer):
    """Compare ``classify`` with the same request sent to ``invocations``.

    The identical payload is POSTed to both URLs. The responses must share
    the same top-level keys, contain the same number of ``data`` entries,
    and each pair of entries must have the same keys and ``probs`` that
    agree within a 1% relative tolerance.
    """
    request_args = {
        "model": MODEL_NAME,
        "input": "This product was excellent and exceeded my expectations"
    }

    classification_response = requests.post(server.url_for("classify"),
                                            json=request_args)
    classification_response.raise_for_status()

    invocation_response = requests.post(server.url_for("invocations"),
                                        json=request_args)
    invocation_response.raise_for_status()

    classification_output = classification_response.json()
    invocation_output = invocation_response.json()

    assert classification_output.keys() == invocation_output.keys()

    classification_items = classification_output["data"]
    invocation_items = invocation_output["data"]
    # zip() stops at the shorter input, so without this check a response
    # with missing (or no) entries would pass the loop below vacuously.
    assert len(classification_items) == len(invocation_items)

    for classification_data, invocation_data in zip(classification_items,
                                                    invocation_items):
        assert classification_data.keys() == invocation_data.keys()
        assert classification_data["probs"] == pytest.approx(
            invocation_data["probs"], rel=0.01)
...@@ -153,3 +153,13 @@ def test_chat_template_validation_for_sad_paths(serve_parser): ...@@ -153,3 +153,13 @@ def test_chat_template_validation_for_sad_paths(serve_parser):
args = serve_parser.parse_args(args=["--chat-template", "does/not/exist"]) args = serve_parser.parse_args(args=["--chat-template", "does/not/exist"])
with pytest.raises(ValueError): with pytest.raises(ValueError):
validate_parsed_serve_args(args) validate_parsed_serve_args(args)
@pytest.mark.parametrize(
    "cli_args, expected_middleware",
    [
        (
            ["--middleware", "middleware1", "--middleware", "middleware2"],
            ["middleware1", "middleware2"],
        ),
        ([], []),
    ],
)
def test_middleware(serve_parser, cli_args, expected_middleware):
    """Check that repeated ``--middleware`` flags are collected in order.

    The second case checks that omitting the flag yields an empty list.
    """
    parsed = serve_parser.parse_args(args=cli_args)
    assert parsed.middleware == expected_middleware
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# imports for guided decoding tests # imports for guided decoding tests
import json import json
import os
import shutil import shutil
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Optional from typing import Optional
...@@ -12,6 +13,7 @@ import pytest ...@@ -12,6 +13,7 @@ import pytest
import os import os
import pytest_asyncio import pytest_asyncio
import regex as re import regex as re
import requests
# downloading lora to test lora requests # downloading lora to test lora requests
# from huggingface_hub import snapshot_download # from huggingface_hub import snapshot_download
from openai import BadRequestError from openai import BadRequestError
...@@ -26,10 +28,6 @@ MODEL_NAME = os.path.join(models_path_prefix, "HuggingFaceH4/zephyr-7b-beta") ...@@ -26,10 +28,6 @@ MODEL_NAME = os.path.join(models_path_prefix, "HuggingFaceH4/zephyr-7b-beta")
# technically these adapters use a different base model, # technically these adapters use a different base model,
# but we're not testing generation quality here # but we're not testing generation quality here
LORA_NAME = os.path.join(models_path_prefix, "typeof/zephyr-7b-beta-lora") LORA_NAME = os.path.join(models_path_prefix, "typeof/zephyr-7b-beta-lora")
PA_NAME = os.path.join(models_path_prefix, "swapnilbp/llama_tweet_ptune")
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
...@@ -57,14 +55,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files): ...@@ -57,14 +55,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def zephyr_pa_files(): def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files):
# return snapshot_download(repo_id=PA_NAME)
return PA_NAME
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
return [ return [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -83,15 +74,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, ...@@ -83,15 +74,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
"64", "64",
"--max-cpu-loras", "--max-cpu-loras",
"2", "2",
# pa config
"--enable-prompt-adapter",
"--prompt-adapters",
f"zephyr-pa={zephyr_pa_files}",
f"zephyr-pa2={zephyr_pa_files}",
"--max-prompt-adapters",
"2",
"--max-prompt-adapter-token",
"128",
] ]
...@@ -100,8 +82,19 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, ...@@ -100,8 +82,19 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
def server(default_server_args, request): def server(default_server_args, request):
if request.param: if request.param:
default_server_args.append(request.param) default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server original_value = os.environ.get('VLLM_USE_V1')
os.environ['VLLM_USE_V1'] = '0'
try:
with RemoteOpenAIServer(MODEL_NAME,
default_server_args) as remote_server:
yield remote_server
finally:
# Restore original env value
if original_value is None:
os.environ.pop('VLLM_USE_V1', None)
else:
os.environ['VLLM_USE_V1'] = original_value
@pytest_asyncio.fixture @pytest_asyncio.fixture
...@@ -112,14 +105,11 @@ async def client(server): ...@@ -112,14 +105,11 @@ async def client(server):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters # first test base model, then test loras
"model_name,num_virtual_tokens", "model_name",
[(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
) )
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
num_virtual_tokens: int):
completion = await client.completions.create(model=model_name, completion = await client.completions.create(model=model_name,
prompt="Hello, my name is", prompt="Hello, my name is",
max_tokens=5, max_tokens=5,
...@@ -132,9 +122,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, ...@@ -132,9 +122,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
assert len(choice.text) >= 5 assert len(choice.text) >= 5
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage( assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, completion_tokens=5, prompt_tokens=6, total_tokens=11)
prompt_tokens=6 + num_virtual_tokens,
total_tokens=11 + num_virtual_tokens)
# test using token IDs # test using token IDs
completion = await client.completions.create( completion = await client.completions.create(
...@@ -177,9 +165,9 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): ...@@ -177,9 +165,9 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters # first test base model, then test loras
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
) )
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs # test using token IDs
...@@ -196,9 +184,9 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): ...@@ -196,9 +184,9 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# just test 1 lora and 1 pa hereafter # just test 1 lora
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs # test using token IDs
...@@ -219,7 +207,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): ...@@ -219,7 +207,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs # test using token IDs
...@@ -240,7 +228,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): ...@@ -240,7 +228,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
model_name: str): model_name: str):
...@@ -316,7 +304,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, ...@@ -316,7 +304,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_completion_streaming(client: openai.AsyncOpenAI, async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str): model_name: str):
...@@ -350,7 +338,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, ...@@ -350,7 +338,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling. """Streaming for parallel sampling.
...@@ -384,7 +372,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): ...@@ -384,7 +372,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_completion_stream_options(client: openai.AsyncOpenAI, async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str): model_name: str):
...@@ -521,7 +509,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, ...@@ -521,7 +509,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test both text and token IDs # test both text and token IDs
...@@ -836,3 +824,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI, ...@@ -836,3 +824,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI,
assert content is not None and saying in content assert content is not None and saying in content
else: else:
assert content is not None and saying not in content assert content is not None and saying not in content
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
                           client: openai.AsyncOpenAI):
    """Check that ``invocations`` mirrors the completions API response.

    The same payload is sent once through the OpenAI client and once as a
    raw POST to the ``invocations`` route; both answers must expose the same
    top-level keys and identical ``choices``.
    """
    payload = dict(
        model=MODEL_NAME,
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
        logprobs=None,
    )

    # Reference answer obtained through the OpenAI client.
    reference = (await client.completions.create(**payload)).model_dump()

    # Same payload posted directly to the invocations route.
    raw_reply = requests.post(server.url_for("invocations"), json=payload)
    raw_reply.raise_for_status()
    mirrored = raw_reply.json()

    assert reference.keys() == mirrored.keys()
    assert reference["choices"] == mirrored["choices"]
...@@ -72,8 +72,43 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, ...@@ -72,8 +72,43 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
"The unit to fetch the temperature in", "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"], "enum": ["celsius", "fahrenheit"],
}, },
"options": {
"$ref": "#/$defs/WeatherOptions",
"description":
"Optional parameters for weather query",
},
}, },
"required": ["country", "unit"], "required": ["country", "unit"],
"$defs": {
"WeatherOptions": {
"title": "WeatherOptions",
"type": "object",
"additionalProperties": False,
"properties": {
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"default": "celsius",
"description": "Temperature unit",
"title": "Temperature Unit",
},
"include_forecast": {
"type": "boolean",
"default": False,
"description":
"Whether to include a 24-hour forecast",
"title": "Include Forecast",
},
"language": {
"type": "string",
"default": "zh-CN",
"description": "Language of the response",
"title": "Language",
"enum": ["zh-CN", "en-US", "ja-JP"],
},
},
},
},
}, },
}, },
}, },
...@@ -145,7 +180,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, ...@@ -145,7 +180,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
"enable_thinking": enable_thinking "enable_thinking": enable_thinking
} }
}) })
if enable_thinking:
assert chat_completion.choices[0].message.\
reasoning_content is not None
assert chat_completion.choices[0].message.\
reasoning_content != ""
assert chat_completion.choices[0].message.tool_calls is not None assert chat_completion.choices[0].message.tool_calls is not None
assert len(chat_completion.choices[0].message.tool_calls) > 0 assert len(chat_completion.choices[0].message.tool_calls) > 0
else: else:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download
from ...conftest import AudioTestAssets
from ...utils import RemoteOpenAIServer
# NOTE - the tests in this module are currently analogous to test_chat, but are
# separated to avoid OOM killing due to module-scoped servers, since we
# need a multimodal model for these tests.
# Contains a modality specific lora alongside the base model
# The value returned by snapshot_download is used below as a filesystem path.
MULTIMODAL_MODEL_NAME = snapshot_download(
"microsoft/Phi-4-multimodal-instruct")
# "speech-lora" entry inside the downloaded model directory; passed to the
# server both as the named "speech" LoRA module and as the default audio LoRA.
AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")
# Exact text the chat completion content is compared against in the test
# below (presumably the transcription produced when the audio LoRA is active).
ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go." # noqa: E501
@pytest.fixture(scope="module")
def monkeypatch_module():
    """Yield a module-scoped ``MonkeyPatch`` for module-scoped fixtures.

    Uses the public ``pytest.MonkeyPatch.context()`` helper instead of the
    private ``_pytest.monkeypatch`` module. The context manager undoes every
    patch on exit, including when the generator is closed early.
    """
    with pytest.MonkeyPatch.context() as mpatch:
        yield mpatch
@pytest.fixture(scope="module", params=[False, True])
def multimodal_server(request, monkeypatch_module):  # noqa: F811
    """Start a module-scoped server with the audio LoRA as default.

    Parametrized over ``VLLM_USE_V1`` ('0' first, then '1'), so every test
    that uses this fixture runs once per setting.

    Yields:
        The running ``RemoteOpenAIServer``; it is shut down when the
        ``with`` block exits.
    """
    # Local import keeps this block self-contained.
    import json

    use_v1 = request.param
    monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "half",
        "--max-model-len",
        "12800",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        f"speech={AUDIO_LORA_PATH}",
        "--max-lora-rank",
        "320",
        "--max-num-seqs",
        "2",
        "--trust-remote-code",
        "--gpu-memory-utilization",
        "0.8",
        "--default-mm-loras",
        # json.dumps escapes quotes/backslashes in the path, which a
        # hand-written JSON f-string would not.
        json.dumps({"audio": AUDIO_LORA_PATH}),
    ]

    with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server:
        yield remote_server
@pytest_asyncio.fixture
async def multi_modal_client(multimodal_server):
    """Yield an async client bound to ``multimodal_server``.

    The client is obtained from ``get_async_client()`` and released when the
    ``async with`` block exits after the test finishes.
    """
    client_cm = multimodal_server.get_async_client()
    async with client_cm as mm_client:
        yield mm_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # base model with default lora should give the same response as lora model
    "model_name",
    [MULTIMODAL_MODEL_NAME, "speech"],
)
async def test_default_mm_lora_chat_completions(
    model_name: str,
    multi_modal_client: openai.AsyncOpenAI,
    audio_assets: AudioTestAssets,
):
    """Transcribe one audio clip and compare it with the expected text.

    Runs once addressing the base model and once addressing the "speech"
    LoRA by name; both must return exactly ``ACTIVE_MM_LORA_RESPONSE``.
    """
    messages = [{
        "role":
        "user",
        "content": [{
            "type": "text",
            "text": "Can you transcribe this audio?",
        }, {
            "type": "audio_url",
            "audio_url": {
                "url": audio_assets[0].url
            },
        }]
    }]

    chat_completion = await multi_modal_client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=128,
        temperature=0.0)

    assert len(chat_completion.choices) > 0

    message = chat_completion.choices[0].message
    # Strictly positive length: a ``>= 0`` bound is always true and would
    # let an empty reply through this check.
    assert message.content is not None and len(message.content) > 0
    assert message.content == ACTIVE_MM_LORA_RESPONSE
...@@ -15,6 +15,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer ...@@ -15,6 +15,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
from ...models.language.pooling.embed_utils import ( from ...models.language.pooling.embed_utils import (
run_embedding_correctness_test) run_embedding_correctness_test)
from ...models.utils import check_embeddings_close
from ...utils import RemoteOpenAIServer, models_path_prefix from ...utils import RemoteOpenAIServer, models_path_prefix
...@@ -298,3 +299,75 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI, ...@@ -298,3 +299,75 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
assert "error" in response.object assert "error" in response.object
assert "truncate_prompt_tokens value is greater than max_model_len. "\ assert "truncate_prompt_tokens value is greater than max_model_len. "\
"Please, select a smaller truncation size." in response.message "Please, select a smaller truncation size." in response.message
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
                           client: openai.AsyncOpenAI):
    """Check that ``invocations`` matches the embeddings API for text input.

    The same request is sent through the OpenAI client and as a raw POST to
    the ``invocations`` route; the responses must have the same keys, the
    same number of data entries, and close embeddings.
    """
    input_texts = [
        "The chef prepared a delicious meal.",
    ]
    request_args = {
        "model": MODEL_NAME,
        "input": input_texts,
        "encoding_format": "float",
    }

    completion_response = await client.embeddings.create(**request_args)

    invocation_response = requests.post(server.url_for("invocations"),
                                        json=request_args)
    invocation_response.raise_for_status()

    completion_output = completion_response.model_dump()
    invocation_output = invocation_response.json()

    assert completion_output.keys() == invocation_output.keys()
    # zip() stops at the shorter sequence, so pin the lengths first; otherwise
    # a short or empty "data" list would pass without comparing anything.
    assert len(completion_output["data"]) == len(invocation_output["data"])
    for completion_data, invocation_data in zip(completion_output["data"],
                                                invocation_output["data"]):
        assert completion_data.keys() == invocation_data.keys()
        check_embeddings_close(embeddings_0_lst=[completion_data["embedding"]],
                               embeddings_1_lst=[invocation_data["embedding"]],
                               name_0="completion",
                               name_1="invocation")
@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
    """Check that ``invocations`` matches ``v1/embeddings`` for chat input.

    A three-message conversation is posted to both routes; the responses
    must have the same keys, the same number of data entries, and close
    embeddings.
    """
    messages = [{
        "role": "user",
        "content": "The cat sat on the mat.",
    }, {
        "role": "assistant",
        "content": "A feline was resting on a rug.",
    }, {
        "role": "user",
        "content": "Stars twinkle brightly in the night sky.",
    }]
    request_args = {
        "model": MODEL_NAME,
        "messages": messages,
        "encoding_format": "float",
    }

    chat_response = requests.post(server.url_for("v1/embeddings"),
                                  json=request_args)
    chat_response.raise_for_status()

    invocation_response = requests.post(server.url_for("invocations"),
                                        json=request_args)
    invocation_response.raise_for_status()

    chat_output = chat_response.json()
    invocation_output = invocation_response.json()

    assert chat_output.keys() == invocation_output.keys()
    # zip() stops at the shorter sequence, so pin the lengths first; otherwise
    # a short or empty "data" list would pass without comparing anything.
    assert len(chat_output["data"]) == len(invocation_output["data"])
    for chat_data, invocation_data in zip(chat_output["data"],
                                          invocation_output["data"]):
        assert chat_data.keys() == invocation_data.keys()
        check_embeddings_close(embeddings_0_lst=[chat_data["embedding"]],
                               embeddings_1_lst=[invocation_data["embedding"]],
                               name_0="chat",
                               name_1="invocation")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from typing import Final from typing import Final
import pytest import pytest
...@@ -29,7 +30,7 @@ def server(): ...@@ -29,7 +30,7 @@ def server():
"--enforce-eager", "--enforce-eager",
"--trust-remote-code", "--trust-remote-code",
"--limit-mm-per-prompt", "--limit-mm-per-prompt",
f"image={MAXIMUM_IMAGES}", json.dumps({"image": MAXIMUM_IMAGES}),
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
...@@ -95,6 +96,10 @@ def test_openapi_stateless(case: schemathesis.Case): ...@@ -95,6 +96,10 @@ def test_openapi_stateless(case: schemathesis.Case):
case.operation.method.upper(), case.operation.method.upper(),
case.operation.path, case.operation.path,
) )
if case.operation.path.startswith("/v1/responses"):
# Skip responses API as it is meant to be stateful.
return
timeout = { timeout = {
# requires a longer timeout # requires a longer timeout
("POST", "/v1/chat/completions"): ("POST", "/v1/chat/completions"):
......
...@@ -13,7 +13,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer ...@@ -13,7 +13,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" MODEL_NAME = "internlm/internlm2-1_8b-reward"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
...@@ -21,15 +21,16 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + ...@@ -21,15 +21,16 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
def server(): def server():
args = [ args = [
"--task", "--task",
"classify", "reward",
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
"bfloat16", "bfloat16",
"--enforce-eager", "--enforce-eager",
"--max-model-len", "--max-model-len",
"8192", "512",
"--chat-template", "--chat-template",
DUMMY_CHAT_TEMPLATE, DUMMY_CHAT_TEMPLATE,
"--trust-remote-code",
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
...@@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): ...@@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None assert poolings.id is not None
assert len(poolings.data) == 1 assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 2 assert len(poolings.data[0].data) == 8
assert poolings.usage.completion_tokens == 0 assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 7 assert poolings.usage.prompt_tokens == 8
assert poolings.usage.total_tokens == 7 assert poolings.usage.total_tokens == 8
# test using token IDs # test using token IDs
input_tokens = [1, 1, 1, 1, 1] input_tokens = [1, 1, 1, 1, 1]
...@@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): ...@@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None assert poolings.id is not None
assert len(poolings.data) == 1 assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 2 assert len(poolings.data[0].data) == 5
assert poolings.usage.completion_tokens == 0 assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 5 assert poolings.usage.prompt_tokens == 5
assert poolings.usage.total_tokens == 5 assert poolings.usage.total_tokens == 5
...@@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): ...@@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None assert poolings.id is not None
assert len(poolings.data) == 3 assert len(poolings.data) == 3
assert len(poolings.data[0].data) == 2 assert len(poolings.data[0].data) == 8
assert poolings.usage.completion_tokens == 0 assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 25 assert poolings.usage.prompt_tokens == 29
assert poolings.usage.total_tokens == 25 assert poolings.usage.total_tokens == 29
# test list[list[int]] # test list[list[int]]
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
...@@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): ...@@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None assert poolings.id is not None
assert len(poolings.data) == 4 assert len(poolings.data) == 4
assert len(poolings.data[0].data) == 2 assert len(poolings.data[0].data) == 5
assert poolings.usage.completion_tokens == 0 assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 17 assert poolings.usage.prompt_tokens == 17
assert poolings.usage.total_tokens == 17 assert poolings.usage.total_tokens == 17
...@@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, ...@@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer,
chat_response.raise_for_status() chat_response.raise_for_status()
chat_poolings = PoolingResponse.model_validate(chat_response.json()) chat_poolings = PoolingResponse.model_validate(chat_response.json())
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") tokenizer = get_tokenizer(
tokenizer_name=model_name,
tokenizer_mode="fast",
trust_remote_code=True,
)
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, messages,
chat_template=DUMMY_CHAT_TEMPLATE, chat_template=DUMMY_CHAT_TEMPLATE,
...@@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ...@@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
) )
float_response.raise_for_status() float_response.raise_for_status()
responses_float = PoolingResponse.model_validate(float_response.json()) responses_float = PoolingResponse.model_validate(float_response.json())
float_data = [
np.array(d.data).squeeze(-1).tolist() for d in responses_float.data
]
base64_response = requests.post( base64_response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
...@@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ...@@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
np.frombuffer(base64.b64decode(data.data), np.frombuffer(base64.b64decode(data.data),
dtype="float32").tolist()) dtype="float32").tolist())
check_embeddings_close( check_embeddings_close(embeddings_0_lst=float_data,
embeddings_0_lst=[d.data for d in responses_float.data], embeddings_1_lst=decoded_responses_base64_data,
embeddings_1_lst=decoded_responses_base64_data, name_0="float32",
name_0="float32", name_1="base64")
name_1="base64")
# Default response is float32 decoded from base64 by OpenAI Client # Default response is float32 decoded from base64 by OpenAI Client
default_response = requests.post( default_response = requests.post(
...@@ -240,9 +247,83 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ...@@ -240,9 +247,83 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
) )
default_response.raise_for_status() default_response.raise_for_status()
responses_default = PoolingResponse.model_validate(default_response.json()) responses_default = PoolingResponse.model_validate(default_response.json())
default_data = [
np.array(d.data).squeeze(-1).tolist() for d in responses_default.data
]
check_embeddings_close(embeddings_0_lst=float_data,
embeddings_1_lst=default_data,
name_0="float32",
name_1="default")
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer):
    """Check that ``invocations`` matches the ``pooling`` route for text.

    The same request is posted to both routes; the responses must have the
    same keys, the same number of data entries, and close pooled values.
    """
    input_texts = [
        "The chef prepared a delicious meal.",
    ]
    request_args = {
        "model": MODEL_NAME,
        "input": input_texts,
        "encoding_format": "float",
    }

    completion_response = requests.post(server.url_for("pooling"),
                                        json=request_args)
    completion_response.raise_for_status()

    invocation_response = requests.post(server.url_for("invocations"),
                                        json=request_args)
    invocation_response.raise_for_status()

    completion_output = completion_response.json()
    invocation_output = invocation_response.json()

    assert completion_output.keys() == invocation_output.keys()
    # zip() stops at the shorter sequence, so pin the lengths first; otherwise
    # a short or empty "data" list would pass without comparing anything.
    assert len(completion_output["data"]) == len(invocation_output["data"])
    for completion_data, invocation_data in zip(completion_output["data"],
                                                invocation_output["data"]):
        assert completion_data.keys() == invocation_data.keys()
        check_embeddings_close(embeddings_0_lst=completion_data["data"],
                               embeddings_1_lst=invocation_data["data"],
                               name_0="completion",
                               name_1="invocation")
@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
messages = [{
"role": "user",
"content": "The cat sat on the mat.",
}, {
"role": "assistant",
"content": "A feline was resting on a rug.",
}, {
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
}]
request_args = {
"model": MODEL_NAME,
"messages": messages,
"encoding_format": "float",
}
chat_response = requests.post(server.url_for("pooling"), json=request_args)
chat_response.raise_for_status()
check_embeddings_close( invocation_response = requests.post(server.url_for("invocations"),
embeddings_0_lst=[d.data for d in responses_default.data], json=request_args)
embeddings_1_lst=[d.data for d in responses_default.data], invocation_response.raise_for_status()
name_0="float32",
name_1="base64") chat_output = chat_response.json()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
for chat_data, invocation_data in zip(chat_output["data"],
invocation_output["data"]):
assert chat_data.keys() == invocation_data.keys()
check_embeddings_close(embeddings_0_lst=chat_data["data"],
embeddings_1_lst=invocation_data["data"],
name_0="chat",
name_1="invocation")
...@@ -94,3 +94,34 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str): ...@@ -94,3 +94,34 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
# Assert just a small fragments of the response # Assert just a small fragments of the response
assert "Please reduce the length of the input." in \ assert "Please reduce the length of the input." in \
rerank_response.text rerank_response.text
def test_invocations(server: RemoteOpenAIServer):
    """Check that ``invocations`` matches the ``rerank`` route.

    The same query/documents request is posted to both routes; the responses
    must have the same keys, the same number of results, and relevance
    scores that agree within 1% relative tolerance.
    """
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]
    request_args = {
        "model": MODEL_NAME,
        "query": query,
        "documents": documents,
    }

    rerank_response = requests.post(server.url_for("rerank"),
                                    json=request_args)
    rerank_response.raise_for_status()

    invocation_response = requests.post(server.url_for("invocations"),
                                        json=request_args)
    invocation_response.raise_for_status()

    rerank_output = rerank_response.json()
    invocation_output = invocation_response.json()

    assert rerank_output.keys() == invocation_output.keys()
    # zip() stops at the shorter sequence, so pin the lengths first; otherwise
    # a short or empty "results" list would pass without comparing anything.
    assert len(rerank_output["results"]) == len(invocation_output["results"])
    for rerank_result, invocations_result in zip(rerank_output["results"],
                                                 invocation_output["results"]):
        assert rerank_result.keys() == invocations_result.keys()
        assert rerank_result["relevance_score"] == pytest.approx(
            invocations_result["relevance_score"], rel=0.01)
...@@ -13,7 +13,6 @@ from ...utils import RemoteOpenAIServer ...@@ -13,7 +13,6 @@ from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args # noqa: F401 from .test_completion import default_server_args # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401 from .test_completion import zephyr_lora_files # noqa: F401
from .test_completion import zephyr_pa_files # noqa: F401
from .test_completion import MODEL_NAME from .test_completion import MODEL_NAME
......
...@@ -193,3 +193,32 @@ class TestModel: ...@@ -193,3 +193,32 @@ class TestModel:
assert score_response.status_code == 400 assert score_response.status_code == 400
assert "Please, select a smaller truncation size." in \ assert "Please, select a smaller truncation size." in \
score_response.text score_response.text
def test_invocations(self, server: RemoteOpenAIServer, model: dict[str,
                                                                   Any]):
    """Check that ``invocations`` matches the ``score`` route.

    The same text pair is posted to both routes; the responses must have
    the same keys, the same number of data entries, and scores that agree
    within 1% relative tolerance.
    """
    text_1 = "What is the capital of France?"
    text_2 = "The capital of France is Paris."

    request_args = {
        "model": model["name"],
        "text_1": text_1,
        "text_2": text_2,
    }

    score_response = requests.post(server.url_for("score"),
                                   json=request_args)
    score_response.raise_for_status()

    invocation_response = requests.post(server.url_for("invocations"),
                                        json=request_args)
    invocation_response.raise_for_status()

    score_output = score_response.json()
    invocation_output = invocation_response.json()

    assert score_output.keys() == invocation_output.keys()
    # zip() stops at the shorter sequence, so pin the lengths first; otherwise
    # a short or empty "data" list would pass without comparing anything.
    assert len(score_output["data"]) == len(invocation_output["data"])
    for score_data, invocation_data in zip(score_output["data"],
                                           invocation_output["data"]):
        assert score_data.keys() == invocation_data.keys()
        assert score_data["score"] == pytest.approx(
            invocation_data["score"], rel=0.01)
...@@ -8,6 +8,8 @@ from dataclasses import dataclass, field ...@@ -8,6 +8,8 @@ from dataclasses import dataclass, field
from typing import Any, Optional from typing import Any, Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
from vllm.config import MultiModalConfig from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.protocol import ChatCompletionRequest
...@@ -75,7 +77,8 @@ def test_async_serving_chat_init(): ...@@ -75,7 +77,8 @@ def test_async_serving_chat_init():
assert serving_completion.chat_template == CHAT_TEMPLATE assert serving_completion.chat_template == CHAT_TEMPLATE
def test_serving_chat_should_set_correct_max_tokens(): @pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
...@@ -90,6 +93,7 @@ def test_serving_chat_should_set_correct_max_tokens(): ...@@ -90,6 +93,7 @@ def test_serving_chat_should_set_correct_max_tokens():
chat_template=CHAT_TEMPLATE, chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto", chat_template_content_format="auto",
request_logger=None) request_logger=None)
req = ChatCompletionRequest( req = ChatCompletionRequest(
model=MODEL_NAME, model=MODEL_NAME,
messages=[{ messages=[{
...@@ -100,13 +104,13 @@ def test_serving_chat_should_set_correct_max_tokens(): ...@@ -100,13 +104,13 @@ def test_serving_chat_should_set_correct_max_tokens():
) )
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93 assert mock_engine.generate.call_args.args[1].max_tokens == 93
req.max_tokens = 10 req.max_tokens = 10
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10 assert mock_engine.generate.call_args.args[1].max_tokens == 10
...@@ -145,7 +149,7 @@ def test_serving_chat_should_set_correct_max_tokens(): ...@@ -145,7 +149,7 @@ def test_serving_chat_should_set_correct_max_tokens():
) )
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10 assert mock_engine.generate.call_args.args[1].max_tokens == 10
...@@ -153,7 +157,7 @@ def test_serving_chat_should_set_correct_max_tokens(): ...@@ -153,7 +157,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 15 req.max_tokens = 15
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10 assert mock_engine.generate.call_args.args[1].max_tokens == 10
...@@ -161,7 +165,7 @@ def test_serving_chat_should_set_correct_max_tokens(): ...@@ -161,7 +165,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 5 req.max_tokens = 5
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 5 assert mock_engine.generate.call_args.args[1].max_tokens == 5
...@@ -200,7 +204,7 @@ def test_serving_chat_should_set_correct_max_tokens(): ...@@ -200,7 +204,7 @@ def test_serving_chat_should_set_correct_max_tokens():
) )
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93 assert mock_engine.generate.call_args.args[1].max_tokens == 93
...@@ -208,7 +212,7 @@ def test_serving_chat_should_set_correct_max_tokens(): ...@@ -208,7 +212,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 100 req.max_tokens = 100
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93 assert mock_engine.generate.call_args.args[1].max_tokens == 93
...@@ -216,12 +220,13 @@ def test_serving_chat_should_set_correct_max_tokens(): ...@@ -216,12 +220,13 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 5 req.max_tokens = 5
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 5 assert mock_engine.generate.call_args.args[1].max_tokens == 5
def test_serving_chat_could_load_correct_generation_config(): @pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
mock_model_config = MockModelConfig() mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = { mock_model_config.diff_sampling_param = {
...@@ -244,6 +249,7 @@ def test_serving_chat_could_load_correct_generation_config(): ...@@ -244,6 +249,7 @@ def test_serving_chat_could_load_correct_generation_config():
chat_template=CHAT_TEMPLATE, chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto", chat_template_content_format="auto",
request_logger=None) request_logger=None)
req = ChatCompletionRequest( req = ChatCompletionRequest(
model=MODEL_NAME, model=MODEL_NAME,
messages=[{ messages=[{
...@@ -254,7 +260,7 @@ def test_serving_chat_could_load_correct_generation_config(): ...@@ -254,7 +260,7 @@ def test_serving_chat_could_load_correct_generation_config():
) )
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.5 assert mock_engine.generate.call_args.args[1].temperature == 0.5
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
...@@ -263,7 +269,7 @@ def test_serving_chat_could_load_correct_generation_config(): ...@@ -263,7 +269,7 @@ def test_serving_chat_could_load_correct_generation_config():
req.temperature = 0.1 req.temperature = 0.1
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.1 assert mock_engine.generate.call_args.args[1].temperature == 0.1
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
...@@ -272,13 +278,14 @@ def test_serving_chat_could_load_correct_generation_config(): ...@@ -272,13 +278,14 @@ def test_serving_chat_could_load_correct_generation_config():
req.temperature = 0.0 req.temperature = 0.0
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.0 assert mock_engine.generate.call_args.args[1].temperature == 0.0
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
def test_serving_chat_did_set_correct_cache_salt(): @pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt():
mock_model_config = MockModelConfig() mock_model_config = MockModelConfig()
mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine = MagicMock(spec=MQLLMEngineClient)
...@@ -308,11 +315,11 @@ def test_serving_chat_did_set_correct_cache_salt(): ...@@ -308,11 +315,11 @@ def test_serving_chat_did_set_correct_cache_salt():
# By default cache_salt in the engine prompt is not set # By default cache_salt in the engine prompt is not set
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert "cache_salt" not in mock_engine.generate.call_args.args[0] assert "cache_salt" not in mock_engine.generate.call_args.args[0]
# Test with certain cache_salt # Test with certain cache_salt
req.cache_salt = "test_salt" req.cache_salt = "test_salt"
with suppress(Exception): with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req)) await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt" assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
...@@ -34,8 +34,7 @@ async def _async_serving_models_init() -> OpenAIServingModels: ...@@ -34,8 +34,7 @@ async def _async_serving_models_init() -> OpenAIServingModels:
serving_models = OpenAIServingModels(engine_client=mock_engine_client, serving_models = OpenAIServingModels(engine_client=mock_engine_client,
base_model_paths=BASE_MODEL_PATHS, base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config, model_config=mock_model_config,
lora_modules=None, lora_modules=None)
prompt_adapters=None)
await serving_models.init_static_loras() await serving_models.init_static_loras()
return serving_models return serving_models
...@@ -59,7 +58,8 @@ async def test_load_lora_adapter_success(): ...@@ -59,7 +58,8 @@ async def test_load_lora_adapter_success():
response = await serving_models.load_lora_adapter(request) response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
assert len(serving_models.lora_requests) == 1 assert len(serving_models.lora_requests) == 1
assert serving_models.lora_requests[0].lora_name == "adapter" assert "adapter" in serving_models.lora_requests
assert serving_models.lora_requests["adapter"].lora_name == "adapter"
@pytest.mark.asyncio @pytest.mark.asyncio
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc import gc
import json import os
import tempfile import tempfile
import openai import openai
...@@ -58,18 +58,20 @@ def tensorize_model_and_lora(tmp_dir, model_uri): ...@@ -58,18 +58,20 @@ def tensorize_model_and_lora(tmp_dir, model_uri):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(model_uri, tensorize_model_and_lora): def server(model_uri, tensorize_model_and_lora):
model_loader_extra_config = { # In this case, model_uri is a directory with a model.tensors
"tensorizer_uri": model_uri, # file and all necessary model artifacts, particularly a
} # HF `config.json` file. In this case, Tensorizer can infer the
# `TensorizerConfig` so --model-loader-extra-config can be completely
# omitted.
## Start OpenAI API server ## Start OpenAI API server
args = [ args = [
"--load-format", "tensorizer", "--device", "cuda", "--load-format", "tensorizer", "--served-model-name", MODEL_NAME,
"--model-loader-extra-config", "--enable-lora"
json.dumps(model_loader_extra_config), "--enable-lora"
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: model_dir = os.path.dirname(model_uri)
with RemoteOpenAIServer(model_dir, args) as remote_server:
yield remote_server yield remote_server
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment