"vllm/vscode:/vscode.git/clone" did not exist on "1591c68fdea97a213d5564f687009c4fd1b44608"
Commit 53076d70 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.2' into v0.8.2-ori

parents 322a0be6 9c5c81b0
...@@ -21,18 +21,9 @@ def test_collective_rpc(tp_size, backend): ...@@ -21,18 +21,9 @@ def test_collective_rpc(tp_size, backend):
def echo_rank(self): def echo_rank(self):
return self.rank return self.rank
from vllm.worker.worker import Worker
class MyWorker(Worker):
def echo_rank(self):
return self.rank
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True, enforce_eager=True,
load_format="dummy", load_format="dummy",
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
distributed_executor_backend=backend, distributed_executor_backend=backend)
worker_cls=MyWorker) assert llm.collective_rpc(echo_rank) == list(range(tp_size))
for method in ["echo_rank", echo_rank]:
assert llm.collective_rpc(method) == list(range(tp_size))
...@@ -14,7 +14,9 @@ from vllm.outputs import RequestOutput ...@@ -14,7 +14,9 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] GUIDED_DECODING_BACKENDS = [
"outlines", "lm-format-enforcer", "xgrammar", "guidance"
]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
......
# SPDX-License-Identifier: Apache-2.0
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from ...utils import RemoteOpenAIServer
# a reasoning and tool calling model
MODEL_NAME = "Qwen/QwQ-32B"
@pytest.fixture(scope="module")
def server(): # noqa: F811
args = [
"--max-model-len", "8192", "--enforce-eager", "--enable-reasoning",
"--reasoning-parser", "deepseek_r1", "--enable-auto-tool-choice",
"--tool-call-parser", "hermes"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
TOOLS = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["city", "state", "unit"]
}
}
}]
MESSAGES = [{
"role": "user",
"content": "Hi! How are you doing today?"
}, {
"role": "assistant",
"content": "I'm doing well! How can I help you?"
}, {
"role":
"user",
"content":
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
FUNC_NAME = "get_current_weather"
FUNC_ARGS = """{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}"""
def extract_reasoning_and_calls(chunks: list):
reasoning_content = ""
tool_call_idx = -1
arguments = []
function_names = []
for chunk in chunks:
if chunk.choices[0].delta.tool_calls:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != tool_call_idx:
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
arguments.append("")
function_names.append("")
if tool_call.function:
if tool_call.function.name:
function_names[tool_call_idx] = tool_call.function.name
if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments
else:
if hasattr(chunk.choices[0].delta, "reasoning_content"):
reasoning_content += chunk.choices[0].delta.reasoning_content
return reasoning_content, arguments, function_names
# test streaming
@pytest.mark.asyncio
async def test_chat_streaming_of_tool_and_reasoning(
client: openai.AsyncOpenAI):
stream = await client.chat.completions.create(
model=MODEL_NAME,
messages=MESSAGES,
tools=TOOLS,
temperature=0.0,
stream=True,
)
chunks = []
async for chunk in stream:
chunks.append(chunk)
reasoning_content, arguments, function_names = extract_reasoning_and_calls(
chunks)
assert len(reasoning_content) > 0
assert len(function_names) > 0 and function_names[0] == FUNC_NAME
assert len(arguments) > 0 and arguments[0] == FUNC_ARGS
# test full generate
@pytest.mark.asyncio
async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI):
tool_calls = await client.chat.completions.create(
model=MODEL_NAME,
messages=MESSAGES,
tools=TOOLS,
temperature=0.0,
stream=False,
)
assert len(tool_calls.choices[0].message.reasoning_content) > 0
assert tool_calls.choices[0].message.tool_calls[0].function.name \
== FUNC_NAME
assert tool_calls.choices[0].message.tool_calls[0].function.arguments \
== FUNC_ARGS
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch from unittest.mock import patch
import pytest import pytest
import torch import torch
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
...@@ -21,9 +20,9 @@ def clear_cache(): ...@@ -21,9 +20,9 @@ def clear_cache():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"]) "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("use_v1", [True, False]) @pytest.mark.parametrize("use_v1", [True, False])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"]) @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env( def test_env(
name: str, name: str,
use_v1: bool, use_v1: bool,
...@@ -49,15 +48,8 @@ def test_env( ...@@ -49,15 +48,8 @@ def test_env(
RocmPlatform()): RocmPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16, backend = get_attn_backend(16, torch.float16, torch.float16,
16, False) 16, False)
EXPECTED = "ROCM_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
assert backend.get_name() == EXPECTED assert backend.get_name() == EXPECTED
elif device == "openvino":
with patch("vllm.attention.selector.current_platform",
OpenVinoPlatform()), patch.dict('sys.modules',
{'openvino': Mock()}):
backend = get_attn_backend(16, torch.float16, torch.float16,
16, False)
assert backend.get_name() == "OPENVINO"
else: else:
if name in ["XFORMERS", "FLASHINFER"]: if name in ["XFORMERS", "FLASHINFER"]:
with patch("vllm.attention.selector.current_platform", with patch("vllm.attention.selector.current_platform",
......
...@@ -15,6 +15,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)] ...@@ -15,6 +15,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16] DTYPES = [torch.float16, torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
# one value large enough to test overflow in index calculation. # one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check # one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048] NUM_BLOCKS = [32768, 2048]
...@@ -85,6 +86,7 @@ def ref_paged_attn( ...@@ -85,6 +86,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256]) @pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("fa_version", [2, 3]) @pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode() @torch.inference_mode()
def test_flash_attn_with_paged_kv( def test_flash_attn_with_paged_kv(
use_out: bool, use_out: bool,
...@@ -97,11 +99,15 @@ def test_flash_attn_with_paged_kv( ...@@ -97,11 +99,15 @@ def test_flash_attn_with_paged_kv(
num_blocks: int, num_blocks: int,
sliding_window: Optional[int], sliding_window: Optional[int],
fa_version: int, fa_version: int,
q_dtype: Optional[torch.dtype],
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version): if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due " pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"") f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
pytest.skip("Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(kv_lens) num_seqs = len(kv_lens)
...@@ -130,10 +136,28 @@ def test_flash_attn_with_paged_kv( ...@@ -130,10 +136,28 @@ def test_flash_attn_with_paged_kv(
q = query.unsqueeze(1) q = query.unsqueeze(1)
out = torch.empty_like(q) if use_out else None out = torch.empty_like(q) if use_out else None
maybe_quantized_query = q
maybe_quantized_key_cache = key_cache
maybe_quantized_value_cache = value_cache
q_descale = None
k_descale = None
v_descale = None
if q_dtype is not None:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query = query.to(q_dtype)
maybe_quantized_key_cache = key_cache.to(q_dtype)
maybe_quantized_value_cache = value_cache.to(q_dtype)
scale_shape = (num_seqs, num_kv_heads)
q_descale = torch.ones(scale_shape, dtype=torch.float32)
k_descale = torch.ones(scale_shape, dtype=torch.float32)
v_descale = torch.ones(scale_shape, dtype=torch.float32)
output = flash_attn_with_kvcache( output = flash_attn_with_kvcache(
q=q, q=maybe_quantized_query,
k_cache=key_cache, k_cache=maybe_quantized_key_cache,
v_cache=value_cache, v_cache=maybe_quantized_value_cache,
out=out, out=out,
softmax_scale=scale, softmax_scale=scale,
causal=True, causal=True,
...@@ -142,10 +166,17 @@ def test_flash_attn_with_paged_kv( ...@@ -142,10 +166,17 @@ def test_flash_attn_with_paged_kv(
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size, window_size=window_size,
fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
) )
output = output if not use_out else out output = output if not use_out else out
output = output.squeeze(1) output = output.squeeze(1)
atol, rtol = 1.5e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
ref_output = ref_paged_attn(query=query, ref_output = ref_paged_attn(query=query,
key_cache=key_cache, key_cache=key_cache,
value_cache=value_cache, value_cache=value_cache,
...@@ -155,7 +186,7 @@ def test_flash_attn_with_paged_kv( ...@@ -155,7 +186,7 @@ def test_flash_attn_with_paged_kv(
scale=scale, scale=scale,
soft_cap=soft_cap, soft_cap=soft_cap,
sliding_window=sliding_window) sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -171,6 +202,7 @@ def test_flash_attn_with_paged_kv( ...@@ -171,6 +202,7 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3]) @pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode() @torch.inference_mode()
def test_varlen_with_paged_kv( def test_varlen_with_paged_kv(
use_out: bool, use_out: bool,
...@@ -183,11 +215,15 @@ def test_varlen_with_paged_kv( ...@@ -183,11 +215,15 @@ def test_varlen_with_paged_kv(
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
fa_version: int, fa_version: int,
q_dtype: Optional[torch.dtype],
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version): if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due " pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"") f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
pytest.skip("Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
...@@ -223,10 +259,28 @@ def test_varlen_with_paged_kv( ...@@ -223,10 +259,28 @@ def test_varlen_with_paged_kv(
dtype=torch.int32) dtype=torch.int32)
out = torch.empty_like(query) if use_out else None out = torch.empty_like(query) if use_out else None
maybe_quantized_query = query
maybe_quantized_key_cache = key_cache
maybe_quantized_value_cache = value_cache
q_descale = None
k_descale = None
v_descale = None
if q_dtype is not None:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query = query.to(q_dtype)
maybe_quantized_key_cache = key_cache.to(q_dtype)
maybe_quantized_value_cache = value_cache.to(q_dtype)
scale_shape = (num_seqs, num_kv_heads)
q_descale = torch.ones(scale_shape, dtype=torch.float32)
k_descale = torch.ones(scale_shape, dtype=torch.float32)
v_descale = torch.ones(scale_shape, dtype=torch.float32)
output = flash_attn_varlen_func( output = flash_attn_varlen_func(
q=query, q=maybe_quantized_query,
k=key_cache, k=maybe_quantized_key_cache,
v=value_cache, v=maybe_quantized_value_cache,
out=out, out=out,
cu_seqlens_q=cu_query_lens, cu_seqlens_q=cu_query_lens,
seqused_k=kv_lens, seqused_k=kv_lens,
...@@ -238,6 +292,9 @@ def test_varlen_with_paged_kv( ...@@ -238,6 +292,9 @@ def test_varlen_with_paged_kv(
block_table=block_tables, block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
) )
output = output if not use_out else out output = output if not use_out else out
...@@ -252,5 +309,8 @@ def test_varlen_with_paged_kv( ...@@ -252,5 +309,8 @@ def test_varlen_with_paged_kv(
sliding_window=sliding_window, sliding_window=sliding_window,
soft_cap=soft_cap, soft_cap=soft_cap,
) )
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ atol, rtol = 1.5e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -26,7 +26,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -26,7 +26,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# Test standard ROCm attention # Test standard ROCm attention
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
assert (backend.get_name() == "ROCM_FLASH" assert (backend.get_name() == "ROCM_FLASH"
or backend.get_name() == "ROCM_ATTN_VLLM_V1") or backend.get_name() == "TRITON_ATTN_VLLM_V1")
# mla test for deepseek related # mla test for deepseek related
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import time import time
from pathlib import Path
import pytest import pytest
from huggingface_hub import snapshot_download
import vllm.envs as env import vllm.envs as env
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
...@@ -13,35 +11,9 @@ from vllm.lora.request import LoRARequest ...@@ -13,35 +11,9 @@ from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import merge_async_iterators from vllm.utils import merge_async_iterators
MODEL_PATH = "meta-llama/Llama-2-7b-hf" MODEL_PATH = "THUDM/chatglm3-6b"
LORA_MODULE_DOWNLOAD_PATH = None # Populated by download_and_prepare_lora_module() #noqa LORA_RANK = 64
LORA_RANK = 8 DEFAULT_MAX_LORAS = 4 * 3
DEFAULT_MAX_LORAS = 16 * 3
def download_and_prepare_lora_module():
"""
Request submission is expensive when the LoRA adapters have their own
tokenizers. This is because, for each request with a new LoRA adapter ID,
the front-end loads the tokenizer from disk.
In this test, as we are comparing request processing times, we want to
minimize any extra activity. To this effect, we download the LoRA
adapter and remove all the tokenizer files, so the engine will default
to the base model tokenizer.
"""
global LORA_MODULE_DOWNLOAD_PATH
LORA_MODULE_HF_PATH = "yard1/llama-2-7b-sql-lora-test"
LORA_MODULE_DOWNLOAD_PATH = snapshot_download(repo_id=LORA_MODULE_HF_PATH)
tokenizer_files = [
'added_tokens.json', 'tokenizer_config.json', 'tokenizer.json',
'tokenizer.model'
]
for tokenizer_file in tokenizer_files:
del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
del_path.unlink(missing_ok=True)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -52,11 +24,9 @@ def v1(run_with_both_engines_lora): ...@@ -52,11 +24,9 @@ def v1(run_with_both_engines_lora):
pass pass
def get_lora_requests() -> list[LoRARequest]: def get_lora_requests(lora_path) -> list[LoRARequest]:
lora_requests: list[LoRARequest] = [ lora_requests: list[LoRARequest] = [
LoRARequest(lora_name=f"{i}", LoRARequest(lora_name=f"{i}", lora_int_id=i, lora_path=lora_path)
lora_int_id=i,
lora_path=LORA_MODULE_DOWNLOAD_PATH)
for i in range(1, DEFAULT_MAX_LORAS + 1) for i in range(1, DEFAULT_MAX_LORAS + 1)
] ]
return lora_requests return lora_requests
...@@ -93,7 +63,7 @@ async def requests_processing_time(llm, ...@@ -93,7 +63,7 @@ async def requests_processing_time(llm,
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_lora(): async def test_add_lora(chatglm3_lora_files):
""" """
The add_lora function is used to pre-load some LoRA adapters into the The add_lora function is used to pre-load some LoRA adapters into the
engine in anticipation of future requests using these adapters. To test engine in anticipation of future requests using these adapters. To test
...@@ -103,10 +73,7 @@ async def test_add_lora(): ...@@ -103,10 +73,7 @@ async def test_add_lora():
We measure the request processing time in both cases and expect the time We measure the request processing time in both cases and expect the time
to be lesser in the case with add_lora() calls. to be lesser in the case with add_lora() calls.
""" """
lora_requests: list[LoRARequest] = get_lora_requests(chatglm3_lora_files)
download_and_prepare_lora_module()
lora_requests: list[LoRARequest] = get_lora_requests()
max_loras = len(set([lr.lora_int_id for lr in lora_requests])) max_loras = len(set([lr.lora_int_id for lr in lora_requests]))
# Create engine in eager-mode. Due to high max_loras, the CI can # Create engine in eager-mode. Due to high max_loras, the CI can
...@@ -118,6 +85,7 @@ async def test_add_lora(): ...@@ -118,6 +85,7 @@ async def test_add_lora():
max_lora_rank=LORA_RANK, max_lora_rank=LORA_RANK,
max_model_len=128, max_model_len=128,
gpu_memory_utilization=0.8, #avoid OOM gpu_memory_utilization=0.8, #avoid OOM
trust_remote_code=True,
enforce_eager=True) enforce_eager=True)
# The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1` # The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1`
......
...@@ -84,12 +84,14 @@ def v1(run_with_both_engines_lora): ...@@ -84,12 +84,14 @@ def v1(run_with_both_engines_lora):
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_llama_lora(sql_lora_files): def test_llama_lora(sql_lora_files):
llm = vllm.LLM(MODEL_PATH, llm = vllm.LLM(
enable_lora=True, MODEL_PATH,
max_num_seqs=16, enable_lora=True,
max_loras=4, # also test odd max_num_seqs
tensor_parallel_size=1, max_num_seqs=13,
enable_chunked_prefill=True) max_loras=4,
tensor_parallel_size=1,
enable_chunked_prefill=True)
generate_and_test(llm, sql_lora_files) generate_and_test(llm, sql_lora_files)
......
...@@ -24,12 +24,10 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): ...@@ -24,12 +24,10 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
) )
lora_request = LoRARequest("1", 1, sql_lora_files) lora_request = LoRARequest("1", 1, sql_lora_files)
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
request_id="request_id", prompt="prompt", lora_request=lora_request) prompt="prompt", lora_request=lora_request)
assert reference_tokenizer.encode( assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async( "prompt") == await tokenizer_group.encode_async(
request_id="request_id", prompt="prompt", lora_request=lora_request)
prompt="prompt",
lora_request=lora_request)
assert isinstance(tokenizer_group.get_lora_tokenizer(None), assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase) PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer( assert tokenizer_group.get_lora_tokenizer(
......
...@@ -7,7 +7,10 @@ from vllm.model_executor.custom_op import CustomOp ...@@ -7,7 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul, from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation, ReLUSquaredActivation,
SiluAndMul) SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import (
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
from vllm.platforms import current_platform
# Registered subclass for test # Registered subclass for test
...@@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str): ...@@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str):
custom_ops=env.split(","))) custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled() RMSNorm(1024).enabled()
@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="AITER is a feature exclusive for ROCm")
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
use_rocm_aiter_norm: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
if not add_residual:
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_norm):
assert rms_norm_func == rocm_aiter_rms_norm
else:
assert rms_norm_func == rms_norm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_norm):
assert rms_norm_func == rocm_aiter_fused_add_rms_norm
else:
assert rms_norm_func == fused_add_rms_norm
...@@ -16,7 +16,9 @@ from vllm.model_executor.guided_decoding.outlines_logits_processors import ( ...@@ -16,7 +16,9 @@ from vllm.model_executor.guided_decoding.outlines_logits_processors import (
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] GUIDED_DECODING_BACKENDS = [
"outlines", "lm-format-enforcer", "xgrammar", "guidance"
]
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"] GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
......
...@@ -3,7 +3,11 @@ ...@@ -3,7 +3,11 @@
Run `pytest tests/models/test_models.py`. Run `pytest tests/models/test_models.py`.
""" """
import pytest import pytest
import torch
from vllm.platforms import current_platform
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
...@@ -13,7 +17,21 @@ from ...utils import check_logprobs_close ...@@ -13,7 +17,21 @@ from ...utils import check_logprobs_close
# https://github.com/vllm-project/vllm/issues/14524 # https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]
# This list contains the model that are using AITER kernel.
# Skip model that are not using AITER tests.
# When more AITER kernels are added, this list will not be
# needed as all the models will be calling AITER kernels
# in parts of the operators
AITER_MODEL_LIST = [
"meta-llama/Llama-3.2-1B-Instruct",
"openbmb/MiniCPM3-4B",
"Qwen/Qwen-7B",
"Qwen/Qwen2.5-0.5B-Instruct",
"ehristoforu/Falcon3-MoE-2x7B-Insruct",
]
# @maybe_test_rocm_aiter
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
...@@ -69,19 +87,24 @@ REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] ...@@ -69,19 +87,24 @@ REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
def test_models( @pytest.mark.parametrize(
hf_runner, "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
vllm_runner, def test_models(hf_runner, vllm_runner, example_prompts, model: str,
example_prompts, dtype: str, max_tokens: int, num_logprobs: int,
model: str, use_rocm_aiter: bool, monkeypatch) -> None:
dtype: str,
max_tokens: int,
num_logprobs: int,
monkeypatch,
) -> None:
if model in REQUIRES_V0: if model in REQUIRES_V0:
monkeypatch.setenv("VLLM_USE_V1", "0") monkeypatch.setenv("VLLM_USE_V1", "0")
if use_rocm_aiter and (model in AITER_MODEL_LIST):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
elif use_rocm_aiter and model not in AITER_MODEL_LIST:
# Skip model that are not using AITER tests.
# When more AITER kernels are added, this list will not be
# needed as all the models will be calling AITER kernels
# in parts of the operators
pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
if model.startswith("THUDM/chatglm3"): if model.startswith("THUDM/chatglm3"):
hf_model.model.get_output_embeddings = lambda: \ hf_model.model.get_output_embeddings = lambda: \
...@@ -100,3 +123,10 @@ def test_models( ...@@ -100,3 +123,10 @@ def test_models(
name_0="hf", name_0="hf",
name_1="vllm", name_1="vllm",
) )
if use_rocm_aiter:
# this is to ensure that vllm engine
# has deallocated the memory before running the next
# unit tests. On ROCm, when using AITER
# the memory might not be deallocated completely
# before running the next test case
torch.cuda.synchronize()
...@@ -508,6 +508,19 @@ VLM_TEST_SETTINGS = { ...@@ -508,6 +508,19 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt={"image": 4}, limit_mm_per_prompt={"image": 4},
)], )],
), ),
# regression test for https://github.com/vllm-project/vllm/issues/15122
"qwen2_5_vl-windows-attention": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
custom_test_opts=[CustomTestOptions(
inputs=custom_inputs.windows_attention_image_qwen2_5_vl(),
limit_mm_per_prompt={"image": 1},
)],
),
} }
# yapf: enable # yapf: enable
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Custom input builders for edge-cases in different models.""" """Custom input builders for edge-cases in different models."""
from io import BytesIO
from typing import Callable from typing import Callable
import requests
from PIL import Image
from vllm.multimodal.image import rescale_image_size from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.video import (rescale_video_size, resize_video, from vllm.multimodal.video import (rescale_video_size, resize_video,
sample_frames_from_video) sample_frames_from_video)
...@@ -102,3 +106,17 @@ def different_patch_input_cases_internvl(): ...@@ -102,3 +106,17 @@ def different_patch_input_cases_internvl():
build_single_image_inputs(images, formatted_sprompts, wrapped_sf), build_single_image_inputs(images, formatted_sprompts, wrapped_sf),
build_multi_image_inputs([images], formatted_mprompts, wrapped_sf), build_multi_image_inputs([images], formatted_mprompts, wrapped_sf),
] ]
def windows_attention_image_qwen2_5_vl():
# image from regression issue: https://github.com/vllm-project/vllm/issues/15122
image_url = "https://aomediacodec.github.io/av1-avif/testFiles/Link-U/hato.jpg"
image = Image.open(BytesIO(requests.get(image_url).content))
question = "Describe the image."
img_prompt = "<|vision_start|><|image_pad|><|vision_end|>"
prompt = (f"<|im_start|>User\n{img_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
wrapped_sf = ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=[0.5])
return build_single_image_inputs([image], [prompt], wrapped_sf)
...@@ -215,7 +215,6 @@ def _run_test( ...@@ -215,7 +215,6 @@ def _run_test(
max_num_seqs=2, max_num_seqs=2,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model: }) as vllm_model:
vllm_outputs_per_image = [ vllm_outputs_per_image = [
...@@ -425,7 +424,6 @@ def test_bnb_regression( ...@@ -425,7 +424,6 @@ def test_bnb_regression(
dtype=dtype, dtype=dtype,
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
enforce_eager=True,
quantization="bitsandbytes", quantization="bitsandbytes",
load_format="bitsandbytes", load_format="bitsandbytes",
) )
...@@ -481,7 +479,6 @@ def test_explicit_implicit_prompt( ...@@ -481,7 +479,6 @@ def test_explicit_implicit_prompt(
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
tensor_parallel_size=1, tensor_parallel_size=1,
enforce_eager=True,
) )
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, temperature=0,
...@@ -513,7 +510,6 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, ...@@ -513,7 +510,6 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
tensor_parallel_size=1, tensor_parallel_size=1,
enforce_eager=True,
limit_mm_per_prompt={"image": limit_mm_per_prompt={"image":
_LIMIT_IMAGE_PER_PROMPT}) as vllm_model: _LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
......
...@@ -192,6 +192,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -192,6 +192,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"), "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
trust_remote_code=True), trust_remote_code=True),
"TeleFLMForCausalLM": _HfExamplesInfo("CofeAI/FLM-2-52B-Instruct-2407",
trust_remote_code=True),
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
is_available_online=False, is_available_online=False,
trust_remote_code=True), trust_remote_code=True),
......
...@@ -6,7 +6,7 @@ from vllm.core.scheduler import Scheduler ...@@ -6,7 +6,7 @@ from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler as V1Scheduler from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
......
...@@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, ...@@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def maybe_assert_ngram_worker(llm): def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified. # Verify the proposer worker is ngram if ngram is specified.
if (llm.llm_engine.speculative_config is not None if (llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0): and llm.llm_engine.speculative_config.method == "ngram"):
from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance( assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker, llm.llm_engine.model_executor.driver_worker.proposer_worker,
......
...@@ -7,28 +7,39 @@ from vllm import SamplingParams ...@@ -7,28 +7,39 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize("common_llm_kwargs", [{ @pytest.mark.parametrize("common_llm_kwargs",
"model": "meta-llama/Llama-3.2-1B-Instruct", [{
"speculative_model": "JackFram/llama-68m", "model": "meta-llama/Llama-3.2-1B-Instruct",
"num_speculative_tokens": 5, }])
}])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"per_test_common_llm_kwargs", "per_test_common_llm_kwargs",
[ [
{ {
# Speculative max model len > overridden max model len should raise. # Speculative max model len > overridden max model len should raise.
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 129,
},
"max_model_len": 128, "max_model_len": 128,
"speculative_max_model_len": 129,
}, },
{ {
# Speculative max model len > draft max model len should raise. # Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12 # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1, "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 2048 + 1,
},
}, },
{ {
# Speculative max model len > target max model len should raise. # Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18 # https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_max_model_len": 131072 + 1, "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 131072 + 1,
},
}, },
]) ])
@pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{}])
......
...@@ -57,8 +57,10 @@ PRECISION = "float32" ...@@ -57,8 +57,10 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [{
{ "speculative_config": {
"speculative_model": SPEC_MODEL, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
}, },
{ }, {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True, "disable_logprobs": True,
}, },
]) }])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
]) ])
...@@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int, batch_size: int, output_len: int, seed: int,
logprobs: int): logprobs: int):
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
common_llm_kwargs, vllm_runner,
per_test_common_llm_kwargs, common_llm_kwargs,
baseline_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, baseline_llm_kwargs,
batch_size, test_llm_kwargs,
output_len, batch_size,
seed, output_len,
logprobs=logprobs, seed,
prompt_logprobs=logprobs, logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ prompt_logprobs=logprobs,
'disable_logprobs_during_spec_decoding']) disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph( ...@@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption( ...@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": k, "model": SPEC_MODEL,
"num_speculative_tokens": k,
},
} }
# Try a range of num. speculative tokens # Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS) for k in range(1, 1 + MAX_SPEC_TOKENS)
...@@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs, ...@@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": SPEC_MODEL, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4 "disable_by_batch_size": 4,
}]) },
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "yuhuili/EAGLE-llama2-chat-7B", "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": "yuhuili/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct", "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment