Commit 4eabe123 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import gc import gc
import json
import os import os
import pathlib import pathlib
import subprocess import subprocess
from functools import partial
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import openai
import pytest import pytest
import torch import torch
from huggingface_hub import snapshot_download
from vllm import SamplingParams from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
...@@ -22,12 +18,11 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, ...@@ -22,12 +18,11 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
is_vllm_tensorized, is_vllm_tensorized,
load_with_tensorizer, load_with_tensorizer,
open_stream, open_stream,
serialize_vllm_model,
tensorize_vllm_model) tensorize_vllm_model)
# yapf: enable # yapf: enable
from vllm.utils import PlaceholderModule, import_from_path from vllm.utils import PlaceholderModule
from ..utils import VLLM_PATH, RemoteOpenAIServer from ..utils import VLLM_PATH
try: try:
from tensorizer import EncryptionParams from tensorizer import EncryptionParams
...@@ -103,6 +98,7 @@ def test_can_deserialize_s3(vllm_runner): ...@@ -103,6 +98,7 @@ def test_can_deserialize_s3(vllm_runner):
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs( def test_deserialized_encrypted_vllm_model_has_same_outputs(
vllm_runner, tmp_path): vllm_runner, tmp_path):
args = EngineArgs(model=model_ref)
with vllm_runner(model_ref) as vllm_model: with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key") key_path = tmp_path / (model_ref + ".key")
...@@ -110,15 +106,13 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( ...@@ -110,15 +106,13 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
outputs = vllm_model.generate(prompts, sampling_params) outputs = vllm_model.generate(prompts, sampling_params)
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path, config_for_serializing = TensorizerConfig(tensorizer_uri=str(model_path),
encryption_keyfile=key_path) encryption_keyfile=str(key_path))
vllm_model.apply_model( tensorize_vllm_model(args, config_for_serializing)
partial(serialize_vllm_model,
tensorizer_config=config_for_serializing))
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, config_for_deserializing = TensorizerConfig(
encryption_keyfile=key_path) tensorizer_uri=str(model_path), encryption_keyfile=str(key_path))
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
...@@ -154,113 +148,46 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, ...@@ -154,113 +148,46 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): def test_load_without_tensorizer_load_format(vllm_runner, capfd):
multilora_inference = import_from_path(
"examples.offline_inference.multilora_inference",
EXAMPLES_PATH / "offline_inference/multilora_inference.py",
)
model_ref = "meta-llama/Llama-2-7b-hf"
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
test_prompts = multilora_inference.create_test_prompts(lora_path)
# Serialize model before deserializing and binding LoRA adapters
with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
vllm_model.apply_model(
partial(
serialize_vllm_model,
tensorizer_config=TensorizerConfig(tensorizer_uri=model_path)))
with vllm_runner(
model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path,
num_readers=1,
),
enable_lora=True,
max_loras=1,
max_lora_rank=8,
max_cpu_loras=2,
max_num_seqs=50,
max_model_len=1000,
) as loaded_vllm_model:
multilora_inference.process_requests(
loaded_vllm_model.model.llm_engine, test_prompts)
assert loaded_vllm_model
def test_load_without_tensorizer_load_format(vllm_runner):
model = None model = None
with pytest.raises(ValueError): try:
model = vllm_runner( model = vllm_runner(
model_ref, model_ref,
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
del model except RuntimeError:
gc.collect() out, err = capfd.readouterr()
torch.cuda.empty_cache() combined_output = out + err
assert ("ValueError: Model loader extra config "
"is not supported for load "
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") "format LoadFormat.AUTO") in combined_output
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): finally:
## Serialize model del model
with vllm_runner(model_ref) as vllm_model: gc.collect()
model_path = tmp_path / (model_ref + ".tensors") torch.cuda.empty_cache()
vllm_model.apply_model(
partial( def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd):
serialize_vllm_model,
tensorizer_config=TensorizerConfig(tensorizer_uri=model_path)))
model_loader_extra_config = {
"tensorizer_uri": str(model_path),
}
## Start OpenAI API server
openai_args = [
"--dtype",
"float16",
"--load-format",
"tensorizer",
"--model-loader-extra-config",
json.dumps(model_loader_extra_config),
]
with RemoteOpenAIServer(model_ref, openai_args) as server:
print("Server ready.")
client = server.get_client()
completion = client.completions.create(model=model_ref,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None
assert len(completion.choices) == 1
assert len(completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
def test_raise_value_error_on_invalid_load_format(vllm_runner):
model = None model = None
with pytest.raises(ValueError): try:
model = vllm_runner( model = vllm_runner(
model_ref, model_ref,
load_format="safetensors", load_format="safetensors",
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
del model except RuntimeError:
gc.collect() out, err = capfd.readouterr()
torch.cuda.empty_cache()
combined_output = out + err
assert ("ValueError: Model loader extra config is not supported "
"for load format LoadFormat.SAFETENSORS") in combined_output
finally:
del model
gc.collect()
torch.cuda.empty_cache()
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner): def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd):
with pytest.raises(ValueError): try:
model_ref = "EleutherAI/pythia-1.4b" model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
...@@ -275,6 +202,13 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner): ...@@ -275,6 +202,13 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
tensor_parallel_size=2, tensor_parallel_size=2,
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
) )
except RuntimeError:
out, err = capfd.readouterr()
combined_output = out + err
assert ("ValueError: For a sharded model, tensorizer_uri "
"should include a string format template like '%04d' "
"to be formatted with the rank "
"of the shard") in combined_output
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
...@@ -288,7 +222,6 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( ...@@ -288,7 +222,6 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
enforce_eager=True, enforce_eager=True,
) as base_model: ) as base_model:
outputs = base_model.generate(prompts, sampling_params) outputs = base_model.generate(prompts, sampling_params)
base_model.model.llm_engine.model_executor.shutdown()
# load model with two shards and serialize with encryption # load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors")) model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
...@@ -296,7 +229,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( ...@@ -296,7 +229,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
tensorizer_config = TensorizerConfig( tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path, tensorizer_uri=model_path,
encryption_keyfile=key_path, encryption_keyfile=str(key_path),
) )
tensorize_vllm_model( tensorize_vllm_model(
...@@ -331,14 +264,13 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): ...@@ -331,14 +264,13 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
model_ref = "facebook/opt-125m" model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path)) config = TensorizerConfig(tensorizer_uri=str(model_path))
args = EngineArgs(model=model_ref, device="cuda")
with vllm_runner(model_ref) as vllm_model: with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params) outputs = vllm_model.generate(prompts, sampling_params)
vllm_model.apply_model( tensorize_vllm_model(args, config)
partial(serialize_vllm_model, tensorizer_config=config)) assert is_vllm_tensorized(config)
assert is_vllm_tensorized(config)
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import enum
import json import json
import logging import logging
import os import os
import sys import sys
import tempfile import tempfile
from dataclasses import dataclass
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any from typing import Any
...@@ -16,6 +17,7 @@ import pytest ...@@ -16,6 +17,7 @@ import pytest
from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger, from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger,
enable_trace_function_call, init_logger) enable_trace_function_call, init_logger)
from vllm.logging_utils import NewLineFormatter from vllm.logging_utils import NewLineFormatter
from vllm.logging_utils.dump_input import prepare_object_to_dump
def f1(x): def f1(x):
...@@ -216,3 +218,37 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): ...@@ -216,3 +218,37 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off():
assert other_logger.handlers != root_logger.handlers assert other_logger.handlers != root_logger.handlers
assert other_logger.level != root_logger.level assert other_logger.level != root_logger.level
assert other_logger.propagate assert other_logger.propagate
def test_prepare_object_to_dump():
str_obj = 'str'
assert prepare_object_to_dump(str_obj) == "'str'"
list_obj = [1, 2, 3]
assert prepare_object_to_dump(list_obj) == '[1, 2, 3]'
dict_obj = {'a': 1, 'b': 'b'}
assert prepare_object_to_dump(dict_obj) in [
"{a: 1, b: 'b'}", "{b: 'b', a: 1}"
]
set_obj = {1, 2, 3}
assert prepare_object_to_dump(set_obj) == '[1, 2, 3]'
tuple_obj = ('a', 'b', 'c')
assert prepare_object_to_dump(tuple_obj) == "['a', 'b', 'c']"
class CustomEnum(enum.Enum):
A = enum.auto()
B = enum.auto()
C = enum.auto()
assert prepare_object_to_dump(CustomEnum.A) == repr(CustomEnum.A)
@dataclass
class CustomClass:
a: int
b: str
assert (prepare_object_to_dump(CustomClass(
1, 'b')) == "CustomClass(a=1, b='b')")
# SPDX-License-Identifier: Apache-2.0
from vllm.outputs import RequestOutput
def test_request_output_forward_compatible():
output = RequestOutput(request_id="test_request_id",
prompt="test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[],
finished=False,
example_arg_added_in_new_version="some_value")
assert output is not None
...@@ -60,6 +60,9 @@ def test_model_from_modelscope(monkeypatch: pytest.MonkeyPatch): ...@@ -60,6 +60,9 @@ def test_model_from_modelscope(monkeypatch: pytest.MonkeyPatch):
# model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary # model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_MODELSCOPE", "True") m.setenv("VLLM_USE_MODELSCOPE", "True")
# Don't use HF_TOKEN for ModelScope repos, otherwise it will fail
# with 400 Client Error: Bad Request.
m.setenv("HF_TOKEN", "")
llm = LLM(model="qwen/Qwen1.5-0.5B-Chat") llm = LLM(model="qwen/Qwen1.5-0.5B-Chat")
prompts = [ prompts = [
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pytest import pytest
from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.messages import (AssistantMessage,
ToolMessage,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool from mistral_common.protocol.instruct.tool_calls import (Function,
FunctionCall, Tool,
ToolCall)
from vllm.transformers_utils.tokenizers.mistral import ( from vllm.transformers_utils.tokenizers.mistral import (
make_mistral_chat_completion_request) make_mistral_chat_completion_request)
# yapf: enable
@pytest.mark.parametrize( @pytest.mark.parametrize(
"openai_request,expected_mistral_request", "openai_request,expected_mistral_request",
[( [(
...@@ -78,6 +81,107 @@ from vllm.transformers_utils.tokenizers.mistral import ( ...@@ -78,6 +81,107 @@ from vllm.transformers_utils.tokenizers.mistral import (
) )
def test_make_mistral_chat_completion_request(openai_request, def test_make_mistral_chat_completion_request(openai_request,
expected_mistral_request): expected_mistral_request):
assert (make_mistral_chat_completion_request( actual_request = make_mistral_chat_completion_request(
openai_request["messages"], openai_request["messages"], openai_request["tools"])
openai_request["tools"]) == expected_mistral_request) assert actual_request == expected_mistral_request
# Tool use with list content and reasoning_content
@pytest.mark.parametrize("openai_request,expected_mistral_request", [(
{
"messages": [
{
"role": "user",
"content": "What's the weather in Paris?",
},
{
"role":
"assistant",
"reasoning_content":
None,
"content":
None,
"tool_calls": [{
"id": "call123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
}],
},
{
"role": "tool",
"content": [{
"type": "text",
"text": "Rainy"
}],
"name": "get_weather",
"tool_call_id": "call123",
},
],
"tools": [{
"type": "function",
"function": {
"name": "get_weather",
"description": "Gets the current weather in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
}
},
"required": ["city"],
},
},
}],
},
ChatCompletionRequest(
messages=[
UserMessage(content="What's the weather in Paris?"),
AssistantMessage(
content=None,
tool_calls=[
ToolCall(
id="call123",
function=FunctionCall(
name="get_weather",
arguments='{"city": "Paris"}',
),
)
],
),
ToolMessage(
content="Rainy",
tool_call_id="call123",
name="get_weather",
),
],
tools=[
Tool(
type="function",
function=Function(
name="get_weather",
description="Gets the current weather in a city.",
parameters={
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
}
},
"required": ["city"],
},
),
)
],
),
)])
def test_make_mistral_chat_completion_request_list_content(
openai_request, expected_mistral_request):
actual_request = make_mistral_chat_completion_request(
openai_request["messages"], openai_request["tools"])
assert actual_request == expected_mistral_request
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
import re
from copy import deepcopy from copy import deepcopy
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import regex as re
from pydantic import TypeAdapter from pydantic import TypeAdapter
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...@@ -333,4 +333,4 @@ def test_streaming_output_valid(output, empty_params, delta_len): ...@@ -333,4 +333,4 @@ def test_streaming_output_valid(output, empty_params, delta_len):
combined_messages += message.tool_calls[0].function.arguments combined_messages += message.tool_calls[0].function.arguments
combined_messages += "}]" combined_messages += "}]"
assert json.loads(combined_messages) == output assert json.loads(combined_messages) == output
assert json.dumps(json.loads(combined_messages)) == output_json assert json.dumps(json.loads(combined_messages)) == output_json
\ No newline at end of file
...@@ -88,7 +88,7 @@ CONFIGS: dict[str, ServerConfig] = { ...@@ -88,7 +88,7 @@ CONFIGS: dict[str, ServerConfig] = {
"meta-llama/Llama-4-Scout-17B-16E-Instruct", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"arguments": [ "arguments": [
"--enforce-eager", "--no-enable-prefix-caching", "--enforce-eager", "--no-enable-prefix-caching",
"--tool-call-parser", "pythonic", "--chat-template", "--tool-call-parser", "llama4_pythonic", "--chat-template",
str(VLLM_PATH / str(VLLM_PATH /
"examples/tool_chat_template_llama4_pythonic.jinja"), "-tp", "examples/tool_chat_template_llama4_pythonic.jinja"), "-tp",
"4" "4"
......
...@@ -19,7 +19,8 @@ def model() -> LLM: ...@@ -19,7 +19,8 @@ def model() -> LLM:
enable_prefix_caching=True, enable_prefix_caching=True,
long_prefill_token_threshold=2, long_prefill_token_threshold=2,
max_num_batched_tokens=6, max_num_batched_tokens=6,
max_num_seqs=3) max_num_seqs=3,
block_size=16)
def test_concurrent_partial_prefill(model): def test_concurrent_partial_prefill(model):
...@@ -27,3 +28,11 @@ def test_concurrent_partial_prefill(model): ...@@ -27,3 +28,11 @@ def test_concurrent_partial_prefill(model):
assert len(outputs) == 3 assert len(outputs) == 3
for output in outputs: for output in outputs:
assert len(output.outputs) == 1 assert len(output.outputs) == 1
def test_prefix_cache_stats_is_recorded(model):
# 17 tokens will make sure first 16 tokens are cached in a block
input_tokens = {"prompt_token_ids": [101] * 17}
_ = model.generate([input_tokens])
outputs = model.generate([input_tokens])
assert outputs[0].num_cached_tokens == 16
...@@ -6,6 +6,7 @@ from typing import Optional ...@@ -6,6 +6,7 @@ from typing import Optional
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
MODEL = "facebook/opt-125m" MODEL = "facebook/opt-125m"
DTYPE = "half" DTYPE = "half"
...@@ -97,3 +98,67 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None: ...@@ -97,3 +98,67 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
raise AssertionError( raise AssertionError(
f"{len(completion_counts)} unique completions; expected" f"{len(completion_counts)} unique completions; expected"
f" {n}. Repeats: {repeats}") f" {n}. Repeats: {repeats}")
def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
max_tokens = 100
# Use spec decoding to test num_accepted_tokens_per_pos
speculative_config = {
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 5,
}
monkeypatch.setenv("VLLM_USE_V1", "1")
with vllm_runner(
MODEL,
speculative_config=speculative_config,
disable_log_stats=False,
) as vllm_model:
model: LLM = vllm_model.model
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens)
outputs = model.generate(example_prompts, sampling_params)
n_prompts = len(example_prompts)
assert len(outputs) == n_prompts
total_tokens = 0
for out in outputs:
assert len(out.outputs) == 1
total_tokens += len(out.outputs[0].token_ids)
assert total_tokens == max_tokens * n_prompts
metrics = model.get_metrics()
def find_metric(name) -> list[Metric]:
found = []
for metric in metrics:
if metric.name == name:
found.append(metric)
return found
num_requests_running = find_metric("vllm:num_requests_running")
assert len(num_requests_running) == 1
assert isinstance(num_requests_running[0], Gauge)
assert num_requests_running[0].value == .0
generation_tokens = find_metric("vllm:generation_tokens")
assert len(generation_tokens) == 1
assert isinstance(generation_tokens[0], Counter)
assert generation_tokens[0].value == total_tokens
request_generation_tokens = find_metric(
"vllm:request_generation_tokens")
assert len(request_generation_tokens) == 1
assert isinstance(request_generation_tokens[0], Histogram)
assert "+Inf" in request_generation_tokens[0].buckets
assert request_generation_tokens[0].buckets["+Inf"] == n_prompts
assert request_generation_tokens[0].count == n_prompts
assert request_generation_tokens[0].sum == total_tokens
num_accepted_tokens_per_pos = find_metric(
"vllm:spec_decode_num_accepted_tokens_per_pos")
assert len(num_accepted_tokens_per_pos) == 1
assert isinstance(num_accepted_tokens_per_pos[0], Vector)
assert len(num_accepted_tokens_per_pos[0].values) == 5
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
from __future__ import annotations from __future__ import annotations
import json import json
import re
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import jsonschema import jsonschema
import pytest import pytest
import regex as re
from pydantic import BaseModel from pydantic import BaseModel
from tests.reasoning.utils import run_reasoning_extraction from tests.reasoning.utils import run_reasoning_extraction
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re
from typing import Optional from typing import Optional
import openai # use the official client for correctness check import openai # use the official client for correctness check
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import regex as re
from openai import BadRequestError from openai import BadRequestError
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
......
...@@ -13,6 +13,8 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2 ...@@ -13,6 +13,8 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2
# Find the git repository root directory # Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel) GIT_ROOT=$(git rev-parse --show-toplevel)
SMI_BIN=$(which nvidia-smi || which rocm-smi)
# Trap the SIGINT signal (triggered by Ctrl+C) # Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
...@@ -44,6 +46,13 @@ get_model_args() { ...@@ -44,6 +46,13 @@ get_model_args() {
echo "$extra_args" echo "$extra_args"
} }
get_num_gpus() {
if [[ "$SMI_BIN" == *"nvidia"* ]]; then
echo "$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)"
else
echo "$($SMI_BIN -l | grep GPU | wc -l)"
fi
}
# Function to run tests for a specific model # Function to run tests for a specific model
run_tests_for_model() { run_tests_for_model() {
...@@ -64,7 +73,7 @@ run_tests_for_model() { ...@@ -64,7 +73,7 @@ run_tests_for_model() {
# Start prefill instances # Start prefill instances
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs # Calculate GPU ID - we'll distribute across available GPUs
GPU_ID=$((i % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l))) GPU_ID=$((i % $(get_num_gpus)))
# Calculate port number (base port + instance number) # Calculate port number (base port + instance number)
PORT=$((8100 + i)) PORT=$((8100 + i))
# Calculate side channel port # Calculate side channel port
...@@ -96,7 +105,7 @@ run_tests_for_model() { ...@@ -96,7 +105,7 @@ run_tests_for_model() {
# Start decode instances # Start decode instances
for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs
GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l))) GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(get_num_gpus)))
# Calculate port number (base port + instance number) # Calculate port number (base port + instance number)
PORT=$((8200 + i)) PORT=$((8200 + i))
# Calculate side channel port # Calculate side channel port
......
...@@ -239,3 +239,11 @@ def get_connector_events() -> dict[str, list[str]]: ...@@ -239,3 +239,11 @@ def get_connector_events() -> dict[str, list[str]]:
print(f"[ERROR] Could not read connector events for {name}: {e}") print(f"[ERROR] Could not read connector events for {name}: {e}")
return connector_events return connector_events
def test_engine_id_conflict():
configs = [KVTransferConfig() for _ in range(2)]
ids = [config.engine_id for config in configs]
assert ids[0] != ids[1], (
"Engine IDs should be different for different configs. "
f"Got {ids}")
...@@ -340,3 +340,84 @@ def test_full_block_prompt(): ...@@ -340,3 +340,84 @@ def test_full_block_prompt():
output = outputs[0] output = outputs[0]
assert output.finish_reason == FinishReason.STOP assert output.finish_reason == FinishReason.STOP
assert_scheduler_empty(scheduler) assert_scheduler_empty(scheduler)
def test_cannot_schedule_after_recv():
"""
Test that we can handle no schedule after recv due to not
enough remaining KV blocks.
"""
# NOTE: the KVCacheManager will use 1 null block.
# So there are 5 total working blocks.
TOTAL_NUM_BLOCKS = 6
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)
# Prime the KVCache.
NUM_PROMPT_BLOCKS = 2
BLOCK_SIZE = vllm_config.cache_config.block_size
# Prompt will use 2 blocks + 1 block after we schedule.
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
request_remote = create_request(request_id=2,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True)
# STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# Step 2: 5 blocks are in use (2 new for remote blocks).
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Step 3: finish recving (5 blocks in use)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal], finished_recving=[request_remote.request_id])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Step 4: try to schedule, not enough blocks.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Step 5: finish the request, free it.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal],
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# Step 6: now we can schedule (with 2 blocks computed).
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
NUM_PROMPT_BLOCKS * BLOCK_SIZE)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# Step 7: free everything.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote],
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)
...@@ -16,31 +16,40 @@ VOCAB_SIZE = 128 * 1024 ...@@ -16,31 +16,40 @@ VOCAB_SIZE = 128 * 1024
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
@pytest.fixture(autouse=True)
def reset_default_device():
"""
Explicitly set the default device, which can affect subsequent tests.
Adding this fixture helps avoid this problem.
"""
original_device = torch.get_default_device()
yield
torch.set_default_device(original_device)
def test_topk_impl_equivalance(): def test_topk_impl_equivalance():
with torch.device(DEVICE): torch.set_default_device(DEVICE)
generator = Generator(device=DEVICE).manual_seed(33) generator = Generator(device=DEVICE).manual_seed(33)
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
# Random top-k values between 1 and 9. # Random top-k values between 1 and 9.
k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator) k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled). # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k.masked_fill_( k.masked_fill_(
torch.randint(0, torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
2, (BATCH_SIZE, ), VOCAB_SIZE)
generator=generator,
dtype=bool), VOCAB_SIZE)
# Top-k only implementation # Top-k only implementation
result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
# Top-p + top-k # Top-p + top-k
no_op_top_p = torch.tensor([1.0]) no_op_top_p = torch.tensor([1.0])
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p) result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
assert torch.allclose(result1, result2) assert torch.allclose(result1, result2)
def test_flashinfer_sampler(): def test_flashinfer_sampler():
...@@ -58,50 +67,49 @@ def test_flashinfer_sampler(): ...@@ -58,50 +67,49 @@ def test_flashinfer_sampler():
pytest.skip( pytest.skip(
"FlashInfer not installed or not available on this platform.") "FlashInfer not installed or not available on this platform.")
with torch.device(DEVICE): torch.set_default_device(DEVICE)
generator = Generator(device=DEVICE).manual_seed(42) generator = Generator(device=DEVICE).manual_seed(42)
# Generate random logits # Generate random logits
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
# Generate various top-k and top-p values # Generate various top-k and top-p values
k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator) k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
p_values = torch.rand( p_values = torch.rand(
(BATCH_SIZE, ), (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0]
generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0]
# Sometimes disable top-k (k=vocab_size)
# Sometimes disable top-k (k=vocab_size) k_values.masked_fill_(
k_values.masked_fill_( torch.randint(0,
torch.randint(0, 2, (BATCH_SIZE, ),
2, (BATCH_SIZE, ), generator=generator,
generator=generator, dtype=torch.bool), VOCAB_SIZE)
dtype=torch.bool), VOCAB_SIZE)
# Sometimes disable top-p (p=1.0)
# Sometimes disable top-p (p=1.0) p_values.masked_fill_(
p_values.masked_fill_( torch.randint(0,
torch.randint(0, 2, (BATCH_SIZE, ),
2, (BATCH_SIZE, ), generator=generator,
generator=generator, dtype=torch.bool), 1.0)
dtype=torch.bool), 1.0)
python_logits = apply_top_k_top_p(
python_logits = apply_top_k_top_p( logits=logits.clone(),
logits=logits.clone(), k=k_values,
k=k_values, p=p_values,
p=p_values, )
) python_probs = torch.softmax(python_logits, dim=-1)
python_probs = torch.softmax(python_logits, dim=-1)
# FlashInfer only exposed renorm interfaces for probs so convert first
# FlashInfer only exposed renorm interfaces for probs so convert first flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
flashinfer_probs = torch.softmax(logits.clone(), dim=-1) flashinfer_probs = top_k_renorm_probs(
flashinfer_probs = top_k_renorm_probs( probs=flashinfer_probs,
probs=flashinfer_probs, top_k=k_values,
top_k=k_values, )
) flashinfer_probs = top_p_renorm_probs(
flashinfer_probs = top_p_renorm_probs( probs=flashinfer_probs,
probs=flashinfer_probs, top_p=p_values,
top_p=p_values, )
)
# Compare the results
# Compare the results assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \ "FlashInfer and Python sampling implementations do not match!"
"FlashInfer and Python sampling implementations do not match!"
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
import regex as re
from vllm import CompletionOutput from vllm import CompletionOutput
......
...@@ -100,8 +100,12 @@ def test_prepare_inputs(): ...@@ -100,8 +100,12 @@ def test_prepare_inputs():
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
# n1 + n2 + n3 - a - b -c
num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
).item()
cu_num_tokens, token_indices = EagleProposer.prepare_inputs( cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
cu_target_query_lens, num_rejected_tokens) cu_target_query_lens, num_rejected_tokens, num_tokens)
assert torch.equal(cu_num_tokens, expected_cu_num_tokens) assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item() assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
...@@ -117,34 +121,13 @@ def test_prepare_inputs(): ...@@ -117,34 +121,13 @@ def test_prepare_inputs():
]) ])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group') @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config') @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry') @mock.patch('vllm.v1.spec_decode.eagle.get_model')
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader') def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
mock_registry, mock_get_layers, mock_get_pp_group, method,
proposer_helper, draft_model_dir, target_attribute_path): proposer_helper, draft_model_dir, target_attribute_path):
# Setup mock for model class # Setup model mock
mock_model_cls = mock.MagicMock() mock_model = mock.MagicMock()
mock_registry.resolve_model_cls.return_value = (mock_model_cls, mock_get_model.return_value = mock_model
"test_arch")
# Create a real context manager for mocks
class MockContextManager:
def __init__(self):
pass
def __enter__(self):
return None
def __exit__(self, exc_type, exc_val, exc_tb):
return False
# Make the mocks return actual context manager objects
mock_set_dtype.return_value = MockContextManager()
mock_set_config.return_value = MockContextManager()
# Setup mocks for attention layers # Setup mocks for attention layers
target_attn_layers = { target_attn_layers = {
...@@ -164,25 +147,6 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader, ...@@ -164,25 +147,6 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
mock_pp_group.world_size = 2 if method == "eagle" else 1 mock_pp_group.world_size = 2 if method == "eagle" else 1
mock_get_pp_group.return_value = mock_pp_group mock_get_pp_group.return_value = mock_pp_group
# Setup model loader mock
mock_loader = mock.MagicMock()
mock_get_loader.return_value = mock_loader
# Setup model mock
mock_model = mock.MagicMock()
mock_model_cls.return_value = mock_model
mock_model.to.return_value = mock_model
# Configure mock to test the attribute sharing path
if method == "eagle":
# For eagle, test the lm_head path
mock_model.load_weights.return_value = {
"model.embed_tokens.weight": torch.zeros(1)
}
else:
# For eagle3, test the embed_tokens path
mock_model.load_weights.return_value = {}
# Setup target model with the appropriate attributes # Setup target model with the appropriate attributes
target_model = mock.MagicMock() target_model = mock.MagicMock()
...@@ -204,13 +168,7 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader, ...@@ -204,13 +168,7 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
proposer.load_model(target_model) proposer.load_model(target_model)
# Verify common interactions # Verify common interactions
mock_get_loader.assert_called_once() mock_get_model.assert_called_once()
mock_model_cls.assert_called_once()
mock_model.to.assert_called_once()
mock_model.load_weights.assert_called_once()
# Verify the loader was called with the right config
mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)
# Verify the specific attribute sharing based on the method # Verify the specific attribute sharing based on the method
if method == "eagle": if method == "eagle":
...@@ -288,6 +246,9 @@ def test_propose(num_speculative_tokens): ...@@ -288,6 +246,9 @@ def test_propose(num_speculative_tokens):
# Assign the mock to the proposer # Assign the mock to the proposer
proposer.model = model_mock proposer.model = model_mock
# Assign draft attn_layer_names since load_model is not invoked
proposer.attn_layer_names = ["layer.0"]
# Create input tensors # Create input tensors
cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens], cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
dtype=torch.int32, dtype=torch.int32,
......
# SPDX-License-Identifier: Apache-2.0
import prometheus_client
import pytest
from vllm.v1.metrics.reader import (Counter, Gauge, Histogram, Vector,
get_metrics_snapshot)
@pytest.fixture(autouse=True)
def test_registry(monkeypatch):
# Use a custom registry for tests
test_registry = prometheus_client.CollectorRegistry(auto_describe=True)
monkeypatch.setattr("vllm.v1.metrics.reader.REGISTRY", test_registry)
return test_registry
@pytest.mark.parametrize("num_engines", [1, 4])
def test_gauge_metric(test_registry, num_engines):
g = prometheus_client.Gauge("vllm:test_gauge",
"Test gauge metric",
labelnames=["model", "engine_index"],
registry=test_registry)
for i in range(num_engines):
g.labels(model="foo", engine_index=str(i)).set(98.5)
metrics = get_metrics_snapshot()
assert len(metrics) == num_engines
engine_labels = [str(i) for i in range(num_engines)]
for m in metrics:
assert isinstance(m, Gauge)
assert m.name == "vllm:test_gauge"
assert m.value == 98.5
assert m.labels["model"] == "foo"
assert m.labels["engine_index"] in engine_labels
engine_labels.remove(m.labels["engine_index"])
@pytest.mark.parametrize("num_engines", [1, 4])
def test_counter_metric(test_registry, num_engines):
c = prometheus_client.Counter("vllm:test_counter",
"Test counter metric",
labelnames=["model", "engine_index"],
registry=test_registry)
for i in range(num_engines):
c.labels(model="bar", engine_index=str(i)).inc(19)
metrics = get_metrics_snapshot()
assert len(metrics) == num_engines
engine_labels = [str(i) for i in range(num_engines)]
for m in metrics:
assert isinstance(m, Counter)
assert m.name == "vllm:test_counter"
assert m.value == 19
assert m.labels["model"] == "bar"
assert m.labels["engine_index"] in engine_labels
engine_labels.remove(m.labels["engine_index"])
@pytest.mark.parametrize("num_engines", [1, 4])
def test_histogram_metric(test_registry, num_engines):
h = prometheus_client.Histogram("vllm:test_histogram",
"Test histogram metric",
labelnames=["model", "engine_index"],
buckets=[10, 20, 30, 40, 50],
registry=test_registry)
for i in range(num_engines):
hist = h.labels(model="blaa", engine_index=str(i))
hist.observe(42)
hist.observe(21)
hist.observe(7)
metrics = get_metrics_snapshot()
assert len(metrics) == num_engines
engine_labels = [str(i) for i in range(num_engines)]
for m in metrics:
assert isinstance(m, Histogram)
assert m.name == "vllm:test_histogram"
assert m.count == 3
assert m.sum == 70
assert m.buckets["10.0"] == 1
assert m.buckets["20.0"] == 1
assert m.buckets["30.0"] == 2
assert m.buckets["40.0"] == 2
assert m.buckets["50.0"] == 3
assert m.labels["model"] == "blaa"
assert m.labels["engine_index"] in engine_labels
engine_labels.remove(m.labels["engine_index"])
@pytest.mark.parametrize("num_engines", [1, 4])
def test_vector_metric(test_registry, num_engines):
c = prometheus_client.Counter(
"vllm:spec_decode_num_accepted_tokens_per_pos",
"Vector-like counter metric",
labelnames=["position", "model", "engine_index"],
registry=test_registry)
for i in range(num_engines):
c.labels(position="0", model="llama", engine_index=str(i)).inc(10)
c.labels(position="1", model="llama", engine_index=str(i)).inc(5)
c.labels(position="2", model="llama", engine_index=str(i)).inc(1)
metrics = get_metrics_snapshot()
assert len(metrics) == num_engines
engine_labels = [str(i) for i in range(num_engines)]
for m in metrics:
assert isinstance(m, Vector)
assert m.name == "vllm:spec_decode_num_accepted_tokens_per_pos"
assert m.values == [10, 5, 1]
assert m.labels["model"] == "llama"
assert m.labels["engine_index"] in engine_labels
engine_labels.remove(m.labels["engine_index"])
...@@ -12,7 +12,7 @@ UNSUPPORTED_MODELS_V1 = [ ...@@ -12,7 +12,7 @@ UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription "openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder "facebook/bart-large-cnn", # encoder decoder
"mistralai/Mamba-Codestral-7B-v0.1", # mamba "mistralai/Mamba-Codestral-7B-v0.1", # mamba
"hmellor/bamba-tiny-random", # hybrid "hmellor/tiny-random-BambaForCausalLM", # hybrid
"BAAI/bge-m3", # embedding "BAAI/bge-m3", # embedding
] ]
......
...@@ -251,7 +251,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): ...@@ -251,7 +251,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
device=torch.device(device), device=torch.device(device),
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
vocab_size=1024, vocab_size=1024,
kv_cache_config=get_kv_cache_config(), block_size=1,
) )
reqs: list[CachedRequestState] = [] reqs: list[CachedRequestState] = []
req_id_reqs = {} req_id_reqs = {}
...@@ -341,7 +341,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, ...@@ -341,7 +341,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
device=torch.device(device), device=torch.device(device),
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
vocab_size=1024, vocab_size=1024,
kv_cache_config=get_kv_cache_config(), block_size=1,
) )
ref_input_batch: InputBatch = InputBatch( ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size, max_num_reqs=batch_size,
...@@ -350,7 +350,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, ...@@ -350,7 +350,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
device=torch.device(device), device=torch.device(device),
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
vocab_size=1024, vocab_size=1024,
kv_cache_config=get_kv_cache_config(), block_size=1,
) )
reqs: list[CachedRequestState] = [] reqs: list[CachedRequestState] = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment