Commit 675ba75f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-ori

parents 5cc98918 296c6572
...@@ -2,21 +2,20 @@ ...@@ -2,21 +2,20 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any, Union
import pytest import pytest
import torch import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import CompilationLevel from vllm.config import CompilationConfig, CompilationLevel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test
@pytest.fixture(params=None, name="model_info") def models_list(all: bool):
def models_list_fixture(request):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}), ("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
...@@ -33,6 +32,9 @@ def models_list_fixture(request): ...@@ -33,6 +32,9 @@ def models_list_fixture(request):
("meta-llama/Llama-3.2-1B-Instruct", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}),
] ]
if not all:
return TEST_MODELS
if is_quant_method_supported("aqlm"): if is_quant_method_supported("aqlm"):
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", { TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
"quantization": "aqlm" "quantization": "aqlm"
...@@ -77,7 +79,7 @@ def models_list_fixture(request): ...@@ -77,7 +79,7 @@ def models_list_fixture(request):
"optimization_level", "optimization_level",
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE], [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
) )
@pytest.mark.parametrize("model_info", "", indirect=True) @pytest.mark.parametrize("model_info", models_list(all=True))
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_full_graph( def test_full_graph(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
...@@ -91,25 +93,50 @@ def test_full_graph( ...@@ -91,25 +93,50 @@ def test_full_graph(
m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1")
print(f"MODEL={model}") print(f"MODEL={model}")
prompts = [ run_model(optimization_level, model, model_kwargs)
"Hello, my name is",
"The president of the United States is",
"The capital of France is", # TODO(luka) add other supported compilation config scenarios here
"The future of AI is", @pytest.mark.parametrize(
] "compilation_config",
sampling_params = SamplingParams(temperature=0) # additional compile sizes
llm = LLM( [
model=model, CompilationConfig(level=CompilationLevel.PIECEWISE,
enforce_eager=True, compile_sizes=[1, 2])
tensor_parallel_size=1, ])
disable_custom_all_reduce=True, # only test some of the models
compilation_config=optimization_level, @pytest.mark.parametrize("model_info", models_list(all=False))
**model_kwargs, @create_new_process_for_each_test()
) def test_custom_compile_config(
outputs = llm.generate(prompts, sampling_params) model_info: tuple[str, dict[str, Any]],
compilation_config: CompilationConfig,
# Print the outputs. ):
for output in outputs: model, model_kwargs = model_info
prompt = output.prompt print(f"MODEL={model}")
generated_text = output.outputs[0].text run_model(compilation_config, model, model_kwargs)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
def run_model(compile_config: Union[int, CompilationConfig], model: str,
model_kwargs: dict[str, Any]):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(
model=model,
enforce_eager=True,
tensor_parallel_size=1,
disable_custom_all_reduce=True,
compilation_config=compile_config,
**model_kwargs,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import pytest import pytest
import torch import torch
from compressed_tensors.quantization import FP8_DTYPE
import vllm.envs as envs import vllm.envs as envs
import vllm.plugins import vllm.plugins
...@@ -14,9 +13,12 @@ from vllm.config import CompilationConfig, CompilationLevel, VllmConfig ...@@ -14,9 +13,12 @@ from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
from vllm.platforms import current_platform
from .backend import TestBackend from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
...@@ -59,8 +61,8 @@ class TestModel(torch.nn.Module): ...@@ -59,8 +61,8 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("static", [True, False]) @pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled", @pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False]) [True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA") reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cutlass_fp8_enabled): cutlass_fp8_enabled):
torch.set_default_device("cuda") torch.set_default_device("cuda")
......
# Same as test_config.yaml but with model specified
model: config-model
port: 12312
served_model_name: mymodel
tensor_parallel_size: 2
trust_remote_code: true
multi_step_stream_outputs: false
...@@ -747,30 +747,27 @@ class VllmRunner: ...@@ -747,30 +747,27 @@ class VllmRunner:
videos: Optional[PromptVideoInput] = None, videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
) -> list[TextPrompt]: ) -> list[TextPrompt]:
if images is not None:
assert len(prompts) == len(images)
if videos is not None:
assert len(prompts) == len(videos)
if audios is not None: if any(x is not None and len(x) != len(prompts)
assert len(prompts) == len(audios) for x in [images, videos, audios]):
raise ValueError(
"All non-None multimodal inputs must have the same length as "
"prompts")
inputs = [TextPrompt(prompt=prompt) for prompt in prompts] inputs = []
if images is not None: for i, prompt in enumerate(prompts):
for i, image in enumerate(images): multi_modal_data = {}
if image is not None: if images is not None and (image := images[i]) is not None:
inputs[i]["multi_modal_data"] = {"image": image} multi_modal_data["image"] = image
if videos is not None and (video := videos[i]) is not None:
if videos is not None: multi_modal_data["video"] = video
for i, video in enumerate(videos): if audios is not None and (audio := audios[i]) is not None:
if video is not None: multi_modal_data["audio"] = audio
inputs[i]["multi_modal_data"] = {"video": video}
if audios is not None: inputs.append(
for i, audio in enumerate(audios): TextPrompt(prompt=prompt,
if audio is not None: multi_modal_data=multi_modal_data
inputs[i]["multi_modal_data"] = {"audio": audio} if multi_modal_data else None))
return inputs return inputs
...@@ -1120,3 +1117,15 @@ def pytest_collection_modifyitems(config, items): ...@@ -1120,3 +1117,15 @@ def pytest_collection_modifyitems(config, items):
for item in items: for item in items:
if "optional" in item.keywords: if "optional" in item.keywords:
item.add_marker(skip_optional) item.add_marker(skip_optional)
@pytest.fixture(scope="session")
def cli_config_file():
"""Return the path to the CLI config file."""
return os.path.join(_TEST_DIR, "config", "test_config.yaml")
@pytest.fixture(scope="session")
def cli_config_file_with_model():
"""Return the path to the CLI config file with model."""
return os.path.join(_TEST_DIR, "config", "test_config_with_model.yaml")
...@@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, ...@@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
check_answers(indices, answer, test_texts) check_answers(indices, answer, test_texts)
def prep_prompts(batch_size: int): def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
""" """
Generate prompts which a bunch of assignments, Generate prompts which a bunch of assignments,
then asking for the value of one of them. then asking for the value of one of them.
The prompt is just under 10k tokens; sliding window is 4k The prompt is just under 10k tokens; sliding window is 4k
so the answer is outside sliding window, but should still be correct. so the answer is outside sliding window, but should still be correct.
Args:
batch_size: number of prompts to generate
ln_range: an argument to control the length of the prompt
""" """
prompts: list[str] = [] prompts: list[str] = []
answer: list[int] = [] answer: list[int] = []
...@@ -145,7 +149,7 @@ def prep_prompts(batch_size: int): ...@@ -145,7 +149,7 @@ def prep_prompts(batch_size: int):
indices.append(idx) indices.append(idx)
prompt = "```python\n# We set a number of variables, " + \ prompt = "```python\n# We set a number of variables, " + \
f"x{idx} will be important later\n" f"x{idx} will be important later\n"
ln = random.randint(800, 1100) ln = random.randint(*ln_range)
for k in range(30, ln): for k in range(30, ln):
v = random.randint(10, 99) v = random.randint(10, 99)
if k == idx: if k == idx:
...@@ -157,7 +161,10 @@ def prep_prompts(batch_size: int): ...@@ -157,7 +161,10 @@ def prep_prompts(batch_size: int):
return prompts, answer, indices return prompts, answer, indices
def check_answers(indices: list[int], answer: list[int], outputs: list[str]): def check_answers(indices: list[int],
answer: list[int],
outputs: list[str],
accept_rate: float = 0.7):
answer2 = [int(text[0:2].strip()) for text in outputs] answer2 = [int(text[0:2].strip()) for text in outputs]
print(list(zip(indices, zip(answer, answer2)))) print(list(zip(indices, zip(answer, answer2))))
numok = 0 numok = 0
...@@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]): ...@@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
numok += 1 numok += 1
frac_ok = numok / len(answer) frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}") print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok > 0.7 assert frac_ok >= accept_rate
def check_window(prompts: list[str]): def check_window(prompts: list[str]):
......
...@@ -106,7 +106,7 @@ def eager_allreduce( ...@@ -106,7 +106,7 @@ def eager_allreduce(
# communicate independently # communicate independently
num_communication = rank // tp_size + 1 num_communication = rank // tp_size + 1
sz = 1024 sz = 1024
fa = get_tp_group().ca_comm fa = get_tp_group().device_communicator.ca_comm
inp = torch.ones(sz, dtype=torch.float32, device=device) inp = torch.ones(sz, dtype=torch.float32, device=device)
out = inp out = inp
for _ in range(num_communication): for _ in range(num_communication):
......
...@@ -175,7 +175,7 @@ TEXT_GENERATION_MODELS = { ...@@ -175,7 +175,7 @@ TEXT_GENERATION_MODELS = {
"inceptionai/jais-13b-chat": PPTestSettings.fast(), "inceptionai/jais-13b-chat": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(), "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersModel # Tests TransformersForCausalLM
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(), "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(), "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
...@@ -217,7 +217,7 @@ EMBEDDING_MODELS = { # type: ignore[var-annotated] ...@@ -217,7 +217,7 @@ EMBEDDING_MODELS = { # type: ignore[var-annotated]
MULTIMODAL_MODELS = { MULTIMODAL_MODELS = {
# [Decoder-only] # [Decoder-only]
"Salesforce/blip2-opt-2.7b": PPTestSettings.fast(), "Salesforce/blip2-opt-6.7b": PPTestSettings.fast(),
"facebook/chameleon-7b": PPTestSettings.fast(), "facebook/chameleon-7b": PPTestSettings.fast(),
"adept/fuyu-8b": PPTestSettings.fast(), "adept/fuyu-8b": PPTestSettings.fast(),
"THUDM/glm-4v-9b": PPTestSettings.fast(), "THUDM/glm-4v-9b": PPTestSettings.fast(),
...@@ -245,7 +245,7 @@ TEST_MODELS = [ ...@@ -245,7 +245,7 @@ TEST_MODELS = [
# [LANGUAGE GENERATION] # [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct", "microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
# "ArthurZ/Ilama-3.2-1B", NOTE: Uncomment after #13905 "ArthurZ/Ilama-3.2-1B",
"ibm/PowerLM-3b", "ibm/PowerLM-3b",
# [LANGUAGE EMBEDDING] # [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct", "intfloat/e5-mistral-7b-instruct",
......
...@@ -13,18 +13,24 @@ import pytest ...@@ -13,18 +13,24 @@ import pytest
from vllm.platforms import current_platform from vllm.platforms import current_platform
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" MODEL_NAMES = [
"Qwen/Qwen2-1.5B-Instruct",
"google/gemma-3-1b-it",
]
NUM_CONCURRENT = 500 NUM_CONCURRENT = 500
TASK = "gsm8k" TASK = "gsm8k"
FILTER = "exact_match,strict-match" FILTER = "exact_match,strict-match"
RTOL = 0.03 RTOL = 0.03
EXPECTED_VALUE = 0.58 EXPECTED_VALUES = {
"Qwen/Qwen2-1.5B-Instruct": 0.58,
"google/gemma-3-1b-it": 0.25,
}
def run_test(more_args=None): def run_test(model_name, more_args=None):
"""Run the end to end accuracy test.""" """Run the end to end accuracy test."""
model_args = f"pretrained={MODEL_NAME},max_model_len=4096" model_args = f"pretrained={model_name},max_model_len=4096"
if more_args is not None: if more_args is not None:
model_args = "{},{}".format(model_args, more_args) model_args = "{},{}".format(model_args, more_args)
...@@ -37,9 +43,12 @@ def run_test(more_args=None): ...@@ -37,9 +43,12 @@ def run_test(more_args=None):
) )
measured_value = results["results"][TASK][FILTER] measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE assert model_name in EXPECTED_VALUES, (
and measured_value + RTOL > EXPECTED_VALUE f"Cannot find the expected value for the model {model_name=}")
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" expected_value = EXPECTED_VALUES[model_name]
assert (measured_value - RTOL < expected_value
and measured_value + RTOL > expected_value
), f"Expected: {expected_value} | Measured: {measured_value}"
# TODO: [AlexM] Fix it with new CI/CD tests # TODO: [AlexM] Fix it with new CI/CD tests
...@@ -49,7 +58,8 @@ TPU_TP_TEST_STR = "" #"tensor_parallel_size=4" ...@@ -49,7 +58,8 @@ TPU_TP_TEST_STR = "" #"tensor_parallel_size=4"
@pytest.mark.skipif(not current_platform.is_cuda() @pytest.mark.skipif(not current_platform.is_cuda()
and not current_platform.is_tpu(), and not current_platform.is_tpu(),
reason="V1 is currently only supported on CUDA and TPU") reason="V1 is currently only supported on CUDA and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch): @pytest.mark.parametrize("model", MODEL_NAMES)
def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
"""Run with the V1 Engine.""" """Run with the V1 Engine."""
with monkeypatch.context() as m: with monkeypatch.context() as m:
...@@ -58,13 +68,13 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch): ...@@ -58,13 +68,13 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
more_args = None more_args = None
if current_platform.is_tpu(): if current_platform.is_tpu():
# Limit compilation time for TPU V1 # Limit compilation time for TPU V1
more_args = "max_num_seqs=64" more_args = "max_model_len=2048,max_num_seqs=64"
# Add TP test (if provided) # Add TP test (if provided)
if TPU_TP_TEST_STR: if TPU_TP_TEST_STR:
more_args += ",{}".format(TPU_TP_TEST_STR) more_args += ",{}".format(TPU_TP_TEST_STR)
run_test(more_args) run_test(model, more_args)
def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch): def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
...@@ -72,4 +82,4 @@ def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch): ...@@ -72,4 +82,4 @@ def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0") m.setenv("VLLM_USE_V1", "0")
run_test() run_test("Qwen/Qwen2-1.5B-Instruct")
...@@ -23,7 +23,19 @@ LORA_NAME = "typeof/zephyr-7b-beta-lora" ...@@ -23,7 +23,19 @@ LORA_NAME = "typeof/zephyr-7b-beta-lora"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def monkeypatch_module():
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="module", params=[False, True])
def llm(request, monkeypatch_module):
use_v1 = request.param
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
# pytest caches the fixture so we use weakref.proxy to # pytest caches the fixture so we use weakref.proxy to
# enable garbage collection # enable garbage collection
llm = LLM(model=MODEL_NAME, llm = LLM(model=MODEL_NAME,
......
...@@ -6,7 +6,6 @@ import weakref ...@@ -6,7 +6,6 @@ import weakref
import jsonschema import jsonschema
import pytest import pytest
from pydantic import BaseModel
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
...@@ -15,7 +14,10 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams ...@@ -15,7 +14,10 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = [ GUIDED_DECODING_BACKENDS = [
"outlines", "lm-format-enforcer", "xgrammar", "guidance" "outlines",
"lm-format-enforcer",
"xgrammar:disable-any-whitespace",
"guidance:disable-any-whitespace",
] ]
...@@ -322,59 +324,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str): ...@@ -322,59 +324,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
print(generated_text) print(generated_text)
assert generated_text is not None assert generated_text is not None
if 'disable-any-whitespace' in guided_decoding_backend:
assert "\n" not in generated_text
# Parse to verify it is valid JSON # Parse to verify it is valid JSON
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) assert isinstance(parsed_json, dict)
@pytest.mark.skip_global_cleanup
def test_json_with_any_whitespace_disabled(llm):
class ResponseSchema(BaseModel):
clarifying_question: str
cost_per_serving: str
calories: str
type_dish_ids: str
type_meal_ids: str
product_ids: list[str]
exclude_product_ids: list[str]
allergen_ids: list[str]
total_cooking_time: str
kitchen_ids: str
holiday_ids: str
# Note: Without this setting, the response is sometimes full of `\n`
# for some models. This option prevents that.
guided_decoding_backend = 'xgrammar:disable-any-whitespace'
schema = ResponseSchema.model_json_schema()
guided_params = GuidedDecodingParams(json=schema,
backend=\
guided_decoding_backend)
sampling_params = SamplingParams(max_tokens=2000,
frequency_penalty=0,
presence_penalty=-1.1,
repetition_penalty=1.3,
guided_decoding=guided_params)
prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
"are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
"quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
outputs = llm.generate(prompts=prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None
assert "\n" not in generated_text
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
jsonschema.validate(instance=parsed_json, schema=schema)
...@@ -11,7 +11,7 @@ import pytest ...@@ -11,7 +11,7 @@ import pytest
import pytest_asyncio import pytest_asyncio
import requests import requests
import torch import torch
from openai import BadRequestError from openai import BadRequestError, OpenAI
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
...@@ -24,7 +24,23 @@ GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] ...@@ -24,7 +24,23 @@ GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811 def monkeypatch_module():
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="module", params=[False, True])
def server(
request,
monkeypatch_module,
zephyr_lora_files, #noqa: F811
zephyr_lora_added_tokens_files): # noqa: F811
use_v1 = request.param
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -49,6 +65,13 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811 ...@@ -49,6 +65,13 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811
yield remote_server yield remote_server
@pytest.fixture
def is_v1_server(server):
import os
assert os.environ['VLLM_USE_V1'] in ['0', '1']
return os.environ['VLLM_USE_V1'] == '1'
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def client(server): async def client(server):
async with server.get_async_client() as async_client: async with server.get_async_client() as async_client:
...@@ -471,8 +494,13 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, ...@@ -471,8 +494,13 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_chat(client: openai.AsyncOpenAI, async def test_guided_choice_chat(client: openai.AsyncOpenAI,
is_v1_server: bool,
guided_decoding_backend: str, guided_decoding_backend: str,
sample_guided_choice): sample_guided_choice):
if is_v1_server and guided_decoding_backend != 'xgrammar':
pytest.skip("Only xgrammar backend is supported with V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -511,9 +539,13 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI, ...@@ -511,9 +539,13 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_json_chat(client: openai.AsyncOpenAI, async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool,
guided_decoding_backend: str, guided_decoding_backend: str,
sample_json_schema): sample_json_schema):
if is_v1_server:
pytest.skip("sample_json_schema has features unsupported in V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -559,7 +591,12 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, ...@@ -559,7 +591,12 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_regex_chat(client: openai.AsyncOpenAI, async def test_guided_regex_chat(client: openai.AsyncOpenAI,
is_v1_server: bool,
guided_decoding_backend: str, sample_regex): guided_decoding_backend: str, sample_regex):
if is_v1_server and guided_decoding_backend != 'xgrammar':
pytest.skip("Only xgrammar backend is supported with V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -617,8 +654,13 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI): ...@@ -617,8 +654,13 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
is_v1_server: bool,
guided_decoding_backend: str, guided_decoding_backend: str,
sample_guided_choice): sample_guided_choice):
if is_v1_server and guided_decoding_backend != 'xgrammar':
pytest.skip("Only xgrammar backend is supported with V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -648,9 +690,13 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, ...@@ -648,9 +690,13 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_named_tool_use(client: openai.AsyncOpenAI, async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
guided_decoding_backend: str, guided_decoding_backend: str,
sample_json_schema): sample_json_schema):
if is_v1_server:
pytest.skip("sample_json_schema has features unsupported on V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -740,53 +786,140 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, ...@@ -740,53 +786,140 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_required_tool_use_not_yet_supported(client: openai.AsyncOpenAI, @pytest.mark.parametrize("model_name", [MODEL_NAME])
sample_json_schema): async def test_required_tool_use(client: openai.AsyncOpenAI,
messages = [{ is_v1_server: bool, model_name: str):
"role": "system", if is_v1_server:
"content": "you are a helpful assistant" pytest.skip(
}, { "tool_choice='required' requires features unsupported on V1")
"role":
"user", tools = [
"content": {
f"Give an example JSON for an employee profile that " "type": "function",
f"fits this schema: {sample_json_schema}" "function": {
}] "name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to find the weather for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"unit": {
"type": "string",
"description":
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["country", "unit"],
},
},
},
{
"type": "function",
"function": {
"name": "get_forecast",
"description": "Get the weather forecast for a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to get the forecast for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
},
"unit": {
"type": "string",
"description":
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["country", "days", "unit"],
},
},
},
]
with pytest.raises(openai.BadRequestError): messages = [
await client.chat.completions.create( {
model=MODEL_NAME, "role": "user",
messages=messages, "content": "Hi! How are you doing today?"
max_completion_tokens=1000, },
tools=[{ {
"type": "function", "role": "assistant",
"function": { "content": "I'm doing well! How can I help you?"
"name": "dummy_function_name", },
"description": "This is a dummy function", {
"parameters": sample_json_schema "role":
} "user",
}], "content":
tool_choice="required") "Can you tell me what the current weather is in Berlin and the "\
"forecast for the next 5 days, in fahrenheit?",
},
]
with pytest.raises(openai.BadRequestError): # Non-streaming test
await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=MODEL_NAME, messages=messages,
messages=messages, model=model_name,
max_completion_tokens=1000, tools=tools,
tools=[{ tool_choice="required",
"type": "function", extra_body=dict(guided_decoding_backend="outlines"),
"function": { )
"name": "dummy_function_name",
"description": "This is a dummy function", assert chat_completion.choices[0].message.tool_calls is not None
"parameters": sample_json_schema assert len(chat_completion.choices[0].message.tool_calls) > 0
}
}], # Streaming test
tool_choice="auto") stream = await client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice="required",
extra_body=dict(guided_decoding_backend="outlines"),
stream=True,
)
output = []
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.tool_calls:
output.extend(chunk.choices[0].delta.tool_calls)
assert len(output) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI, async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
is_v1_server: bool,
sample_json_schema): sample_json_schema):
if is_v1_server:
pytest.skip("sample_json_schema has features unsupported on V1")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -1000,7 +1133,7 @@ async def test_long_seed(client: openai.AsyncOpenAI): ...@@ -1000,7 +1133,7 @@ async def test_long_seed(client: openai.AsyncOpenAI):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_http_chat_wo_model_name(server: RemoteOpenAIServer): async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
url = f"http://localhost:{server.port}/v1/chat/completions" url = f"http://localhost:{server.port}/v1/chat/completions"
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
...@@ -1021,10 +1154,35 @@ async def test_http_chat_wo_model_name(server: RemoteOpenAIServer): ...@@ -1021,10 +1154,35 @@ async def test_http_chat_wo_model_name(server: RemoteOpenAIServer):
response = requests.post(url, headers=headers, json=data) response = requests.post(url, headers=headers, json=data)
response_data = response.json() response_data = response.json()
print(response_data) print(response_data)
assert response_data.get("model") == MODEL_NAME
choice = response_data.get("choices")[0] choice = response_data.get("choices")[0]
message = choice.get("message") message = choice.get("message")
assert message is not None assert message is not None
content = message.get("content") content = message.get("content")
assert content is not None assert content is not None
assert len(content) > 0 assert len(content) > 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME, ""])
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
model_name: str):
openai_api_key = "EMPTY"
openai_api_base = f"http://localhost:{server.port}/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
messages = [
{
"role": "user",
"content": "Hello, vLLM!"
},
]
response = client.chat.completions.create(
model="", # empty string
messages=messages,
)
assert response.model == MODEL_NAME
...@@ -53,7 +53,20 @@ def zephyr_lora_files(): ...@@ -53,7 +53,20 @@ def zephyr_lora_files():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files): def monkeypatch_module():
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="module", params=[False, True])
def server_with_lora_modules_json(request, monkeypatch_module,
zephyr_lora_files):
use_v1 = request.param
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
# Define the json format LoRA module configurations # Define the json format LoRA module configurations
lora_module_1 = { lora_module_1 = {
"name": "zephyr-lora", "name": "zephyr-lora",
......
...@@ -13,9 +13,12 @@ import requests ...@@ -13,9 +13,12 @@ import requests
from prometheus_client.parser import text_string_to_metric_families from prometheus_client.parser import text_string_to_metric_families
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm import version
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
PREV_MINOR_VERSION = version._prev_minor_version()
@pytest.fixture(scope="module", params=[True, False]) @pytest.fixture(scope="module", params=[True, False])
...@@ -55,6 +58,7 @@ def default_server_args(): ...@@ -55,6 +58,7 @@ def default_server_args():
"", "",
"--enable-chunked-prefill", "--enable-chunked-prefill",
"--disable-frontend-multiprocessing", "--disable-frontend-multiprocessing",
f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}",
]) ])
def server(use_v1, default_server_args, request): def server(use_v1, default_server_args, request):
if request.param: if request.param:
...@@ -129,7 +133,9 @@ async def test_metrics_counts(server: RemoteOpenAIServer, ...@@ -129,7 +133,9 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
# Loop over all expected metric_families # Loop over all expected metric_families
for metric_family, suffix_values_list in EXPECTED_VALUES.items(): for metric_family, suffix_values_list in EXPECTED_VALUES.items():
if use_v1 and metric_family not in EXPECTED_METRICS_V1: if ((use_v1 and metric_family not in EXPECTED_METRICS_V1)
or (not server.show_hidden_metrics
and metric_family in HIDDEN_DEPRECATED_METRICS)):
continue continue
found_metric = False found_metric = False
...@@ -165,10 +171,10 @@ async def test_metrics_counts(server: RemoteOpenAIServer, ...@@ -165,10 +171,10 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
EXPECTED_METRICS = [ EXPECTED_METRICS = [
"vllm:num_requests_running", "vllm:num_requests_running",
"vllm:num_requests_swapped", "vllm:num_requests_swapped", # deprecated
"vllm:num_requests_waiting", "vllm:num_requests_waiting",
"vllm:gpu_cache_usage_perc", "vllm:gpu_cache_usage_perc",
"vllm:cpu_cache_usage_perc", "vllm:cpu_cache_usage_perc", # deprecated
"vllm:time_to_first_token_seconds_sum", "vllm:time_to_first_token_seconds_sum",
"vllm:time_to_first_token_seconds_bucket", "vllm:time_to_first_token_seconds_bucket",
"vllm:time_to_first_token_seconds_count", "vllm:time_to_first_token_seconds_count",
...@@ -268,6 +274,11 @@ EXPECTED_METRICS_V1 = [ ...@@ -268,6 +274,11 @@ EXPECTED_METRICS_V1 = [
"vllm:request_decode_time_seconds_count", "vllm:request_decode_time_seconds_count",
] ]
HIDDEN_DEPRECATED_METRICS = [
"vllm:num_requests_swapped",
"vllm:cpu_cache_usage_perc",
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_metrics_exist(server: RemoteOpenAIServer, async def test_metrics_exist(server: RemoteOpenAIServer,
...@@ -282,7 +293,9 @@ async def test_metrics_exist(server: RemoteOpenAIServer, ...@@ -282,7 +293,9 @@ async def test_metrics_exist(server: RemoteOpenAIServer,
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
for metric in (EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS): for metric in (EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS):
assert metric in response.text if (not server.show_hidden_metrics
and metric not in HIDDEN_DEPRECATED_METRICS):
assert metric in response.text
def test_metrics_exist_run_batch(use_v1: bool): def test_metrics_exist_run_batch(use_v1: bool):
......
...@@ -25,15 +25,37 @@ def test_sleep_mode(): ...@@ -25,15 +25,37 @@ def test_sleep_mode():
"VLLM_SERVER_DEV_MODE": "1", "VLLM_SERVER_DEV_MODE": "1",
"CUDA_VISIBLE_DEVICES": "0" "CUDA_VISIBLE_DEVICES": "0"
}) as remote_server: }) as remote_server:
response = requests.post(remote_server.url_for("/sleep"), response = requests.post(remote_server.url_for("sleep"),
data={"level": "1"}) params={"level": "1"})
assert response.status_code == 200 assert response.status_code == 200
response = requests.get(remote_server.url_for("/is_sleeping")) response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json().get("is_sleeping") is True assert response.json().get("is_sleeping") is True
response = requests.post(remote_server.url_for("/wake_up")) response = requests.post(remote_server.url_for("wake_up"))
assert response.status_code == 200 assert response.status_code == 200
response = requests.get(remote_server.url_for("/is_sleeping")) response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200
assert response.json().get("is_sleeping") is False
# test wake up with tags
response = requests.post(remote_server.url_for("sleep"),
params={"level": "1"})
assert response.status_code == 200
response = requests.post(remote_server.url_for("wake_up"),
params={"tags": ["weights"]})
assert response.status_code == 200
# is sleeping should be false after waking up any part of the engine
response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200
assert response.json().get("is_sleeping") is True
response = requests.post(remote_server.url_for("wake_up"),
params={"tags": ["kv_cache"]})
assert response.status_code == 200
response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json().get("is_sleeping") is False assert response.json().get("is_sleeping") is False
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
import openai import openai
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import requests
from PIL import Image
from transformers import AutoProcessor
from vllm.multimodal.utils import encode_image_base64, fetch_image from vllm.multimodal.utils import encode_image_base64, fetch_image
...@@ -53,11 +56,31 @@ def base64_encoded_image() -> dict[str, str]: ...@@ -53,11 +56,31 @@ def base64_encoded_image() -> dict[str, str]:
} }
def get_hf_prompt_tokens(model_name, content, image_url):
processor = AutoProcessor.from_pretrained(model_name,
trust_remote_code=True,
num_crops=4)
placeholder = "<|image_1|>\n"
messages = [{
"role": "user",
"content": f"{placeholder}{content}",
}]
images = [Image.open(requests.get(image_url, stream=True).raw)]
prompt = processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
inputs = processor(prompt, images, return_tensors="pt")
return inputs.input_ids.shape[1]
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image(client: openai.AsyncOpenAI, async def test_single_chat_session_image(client: openai.AsyncOpenAI,
model_name: str, image_url: str): model_name: str, image_url: str):
content_text = "What's in this image?"
messages = [{ messages = [{
"role": "role":
"user", "user",
...@@ -70,16 +93,17 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, ...@@ -70,16 +93,17 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
}, },
{ {
"type": "text", "type": "text",
"text": "What's in this image?" "text": content_text
}, },
], ],
}] }]
max_completion_tokens = 10
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
messages=messages, messages=messages,
max_completion_tokens=10, max_completion_tokens=max_completion_tokens,
logprobs=True, logprobs=True,
temperature=0.0, temperature=0.0,
top_logprobs=5) top_logprobs=5)
...@@ -87,8 +111,12 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, ...@@ -87,8 +111,12 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
image_url)
assert chat_completion.usage == openai.types.CompletionUsage( assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=774, total_tokens=784) completion_tokens=max_completion_tokens,
prompt_tokens=hf_prompt_tokens,
total_tokens=hf_prompt_tokens + max_completion_tokens)
message = choice.message message = choice.message
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
...@@ -150,6 +178,7 @@ async def test_single_chat_session_image_base64encoded( ...@@ -150,6 +178,7 @@ async def test_single_chat_session_image_base64encoded(
client: openai.AsyncOpenAI, model_name: str, image_url: str, client: openai.AsyncOpenAI, model_name: str, image_url: str,
base64_encoded_image: dict[str, str]): base64_encoded_image: dict[str, str]):
content_text = "What's in this image?"
messages = [{ messages = [{
"role": "role":
"user", "user",
...@@ -163,16 +192,17 @@ async def test_single_chat_session_image_base64encoded( ...@@ -163,16 +192,17 @@ async def test_single_chat_session_image_base64encoded(
}, },
{ {
"type": "text", "type": "text",
"text": "What's in this image?" "text": content_text
}, },
], ],
}] }]
max_completion_tokens = 10
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
messages=messages, messages=messages,
max_completion_tokens=10, max_completion_tokens=max_completion_tokens,
logprobs=True, logprobs=True,
temperature=0.0, temperature=0.0,
top_logprobs=5) top_logprobs=5)
...@@ -180,8 +210,12 @@ async def test_single_chat_session_image_base64encoded( ...@@ -180,8 +210,12 @@ async def test_single_chat_session_image_base64encoded(
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
image_url)
assert chat_completion.usage == openai.types.CompletionUsage( assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=774, total_tokens=784) completion_tokens=max_completion_tokens,
prompt_tokens=hf_prompt_tokens,
total_tokens=hf_prompt_tokens + max_completion_tokens)
message = choice.message message = choice.message
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
import pytest import pytest
import requests import requests
from PIL import Image
from transformers import AutoProcessor
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.multimodal.utils import encode_image_base64, fetch_image from vllm.multimodal.utils import encode_image_base64, fetch_image
...@@ -52,11 +54,24 @@ def base64_encoded_image() -> dict[str, str]: ...@@ -52,11 +54,24 @@ def base64_encoded_image() -> dict[str, str]:
} }
def get_hf_prompt_tokens(model_name, content, image_url):
processor = AutoProcessor.from_pretrained(model_name,
trust_remote_code=True,
num_crops=4)
placeholder = "<|image_1|> "
prompt = f"{placeholder}{content}"
images = [Image.open(requests.get(image_url, stream=True).raw)]
inputs = processor(prompt, images, return_tensors="pt")
return inputs.input_ids.shape[1]
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
image_url: str): image_url: str):
content_text = "Represent the given image."
messages = [{ messages = [{
"role": "role":
"user", "user",
...@@ -69,7 +84,7 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, ...@@ -69,7 +84,7 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
}, },
{ {
"type": "text", "type": "text",
"text": "Represent the given image." "text": content_text
}, },
], ],
}] }]
...@@ -85,9 +100,12 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, ...@@ -85,9 +100,12 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
response.raise_for_status() response.raise_for_status()
embeddings = EmbeddingResponse.model_validate(response.json()) embeddings = EmbeddingResponse.model_validate(response.json())
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
image_url)
assert embeddings.id is not None assert embeddings.id is not None
assert len(embeddings.data) == 1 assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 3072 assert len(embeddings.data[0].embedding) == 3072
assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 763 assert embeddings.usage.prompt_tokens == hf_prompt_tokens
assert embeddings.usage.total_tokens == 763 assert embeddings.usage.total_tokens == hf_prompt_tokens
...@@ -9,11 +9,11 @@ from transformers import __version__ as TRANSFORMERS_VERSION ...@@ -9,11 +9,11 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template, from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
_try_extract_ast, load_chat_template,
parse_chat_messages, parse_chat_messages,
parse_chat_messages_futures, parse_chat_messages_futures,
resolve_chat_template_content_format) resolve_chat_template_content_format,
resolve_hf_chat_template)
from vllm.entrypoints.llm import apply_hf_chat_template from vllm.entrypoints.llm import apply_hf_chat_template
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64 from vllm.multimodal.utils import encode_image_base64
...@@ -747,7 +747,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ...@@ -747,7 +747,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
}] if use_tools else None }] if use_tools else None
# Test detecting the tokenizer's chat_template # Test detecting the tokenizer's chat_template
chat_template = _resolve_hf_chat_template( chat_template = resolve_hf_chat_template(
tokenizer, tokenizer,
chat_template=None, chat_template=None,
tools=tools, tools=tools,
...@@ -781,7 +781,7 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -781,7 +781,7 @@ def test_resolve_content_format_hf_defined(model, expected_format):
tokenizer = tokenizer_group.tokenizer tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template # Test detecting the tokenizer's chat_template
chat_template = _resolve_hf_chat_template( chat_template = resolve_hf_chat_template(
tokenizer, tokenizer,
chat_template=None, chat_template=None,
tools=None, tools=None,
......
...@@ -6,12 +6,25 @@ import itertools ...@@ -6,12 +6,25 @@ import itertools
import pytest import pytest
import torch import torch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul) per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
dg_available = False
try:
import deep_gemm
dg_available = True
except ImportError:
pass
if current_platform.get_device_capability() < (9, 0): if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True) allow_module_level=True)
...@@ -21,17 +34,18 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] ...@@ -21,17 +34,18 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048] NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824] D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 256, 512] GROUP_SIZE = [64, 128, 256, 512]
M = [1, 7, 83, 512, 2048] M = [1, 7, 8, 83, 84, 512, 2048, 4096]
N = [128, 512, 1024, 4096, 7748, 13824] N = [128, 512, 1024, 4096, 7168, 7748, 13824]
K = [256, 4096, 5120, 3884, 13824] K = [256, 4096, 5120, 3884, 13824, 16384]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168. # and its hidden size is 7168.
M_moe = [1, 7, 83, 512, 2048] M_moe = [1, 2, 7, 83, 128, 512, 2048]
N_moe = [4608] # [128, 4608, 13824] M_moe_dg = [128, 192, 512, 1335, 2048]
K_moe = [7168] # [256, 7168, 13824] N_moe = [128, 256, 1024, 4608] # [13824]
K_moe = [256, 512, 7168] # [13824]
BLOCK_SIZE = [[128, 128]] BLOCK_SIZE = [[128, 128]]
E = [8, 24] # [8, 24, 128, 256] E = [2, 8, 16, 24] # [128, 256]
TOP_KS = [2] # [1, 2, 6] TOP_KS = [1, 2, 6]
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0] SEEDS = [0]
...@@ -217,11 +231,16 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -217,11 +231,16 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
SEEDS)) SEEDS))
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")
torch.manual_seed(seed) torch.manual_seed(seed)
factor_for_scale = 1e-2 factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min fp8_max, fp8_min = fp8_info.max, fp8_info.min
vllm_config = VllmConfig()
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
w1_bf16 = (torch.rand( w1_bf16 = (torch.rand(
...@@ -246,25 +265,240 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ...@@ -246,25 +265,240 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
out = fused_moe( # Set the context to avoid lots of warning spam.
a, with set_current_vllm_config(vllm_config):
w1, out = fused_moe(
w2, a,
score, w1,
topk, w2,
renormalize=False, score,
use_fp8_w8a8=True, topk,
w1_scale=w1_s, renormalize=False,
w2_scale=w2_s, use_fp8_w8a8=True,
block_shape=block_size, w1_scale=w1_s,
) w2_scale=w2_s,
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape=block_size,
block_size) )
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
print(f"{out.sum()=}") block_size)
print(f"{ref_out.sum()=}")
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes
if M % 4 != 0 or K % 128 != 0 or N % 64 != 0:
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
torch.manual_seed(seed)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max = fp8_info.max
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
_, block_k = block_size[0], block_size[1]
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
As = As_fp8.to(torch.float32)
Bs = Bs_fp8.to(torch.float32)
ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) //
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.001
def fp8_perm(m, idx):
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
M, K = a.shape
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
num_tokens = topk * M
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
a = fp8_perm(a, sorted_token_ids // topk)
if a_s is not None:
a_s = a_s[sorted_token_ids // topk]
return a, a_s, m_indices, inv_perm
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
M = topk_weight.shape[0]
out = out[inv_perm, ...]
tmp_out = out.view(-1, topk, K)
return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
block_shape):
"""Fused moe with block-wise quantization using DeepGemm grouped gemm."""
num_groups = w1.shape[0]
M, K = a.shape
N = w2.shape[-1]
topk_weight, topk_ids = fused_topk(a, score.float(), topk, False)
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = per_token_group_quant_fp8(a, block_m)
a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
num_groups, topk, block_m)
inter_out = torch.zeros((a_q.shape[0], N * 2),
dtype=torch.bfloat16,
device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
inter_out, m_indices)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
return final_out
@pytest.mark.parametrize(
"M,N,K,E,topk,seed",
itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
dtype = torch.bfloat16
# only aligned sizes
if (N % block_m != 0 or K % block_m != 0 or topk > E):
pytest.skip(
f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
if N <= 512:
pytest.skip("Skipping N <= 512 until performance issues solved.")
vllm_config = VllmConfig()
torch.manual_seed(seed)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
fp8_max).clamp(min=fp8_min, max=fp8_max)
w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
fp8_max).clamp(min=fp8_min, max=fp8_max)
score = torch.randn((M, E), dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w2 = (N + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(E):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
if M >= 128:
ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
score, topk, block_size)
else:
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
topk, block_size)
topk_weights, topk_ids = fused_topk(a, score.float(), topk, False)
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff = (torch.mean( rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32)))) torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03 assert rel_diff < 0.03
...@@ -749,3 +749,72 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, ...@@ -749,3 +749,72 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
torch.testing.assert_close(dst, expected) torch.testing.assert_close(dst, expected)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
kv_lora_rank: int,
qk_rope_head_dim: int,
num_tokens: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
) -> None:
device = "cpu"
kv_cache_dtype = "auto"
current_platform.seed_everything(seed)
torch.set_default_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache,
ref_temp,
scale.item(),
kv_dtype=kv_cache_dtype)
else:
ref_kv_cache = ref_temp
opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
torch.testing.assert_close(kv_cache, ref_kv_cache)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment