Unverified commit a73e183e, authored by Sibi and committed by GitHub
Browse files

[Misc] Replace os.environ with monkeypatch in test suite (#14516)


Signed-off-by: sibi <85477603+t-sibiraj@users.noreply.github.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
parent 1e799b7e
...@@ -12,11 +12,10 @@ import pytest ...@@ -12,11 +12,10 @@ import pytest
from tests.kernels.utils import override_backend_env_variable from tests.kernels.utils import override_backend_env_variable
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
os.environ["TOKENIZERS_PARALLELISM"] = "true"
@pytest.mark.quant_model @pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("fp8"), @pytest.mark.skipif(not is_quant_method_supported("fp8"),
...@@ -55,45 +54,47 @@ def test_models( ...@@ -55,45 +54,47 @@ def test_models(
backend: str, backend: str,
tensor_parallel_size: int, tensor_parallel_size: int,
disable_async_output_proc: bool, disable_async_output_proc: bool,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
""" """
Only checks log probs match to cover the discrepancy in Only checks log probs match to cover the discrepancy in
numerical sensitive kernels. numerical sensitive kernels.
""" """
override_backend_env_variable(monkeypatch, backend) with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", 'true')
MAX_MODEL_LEN = 1024 m.setenv(STR_BACKEND_ENV_VAR, backend)
NUM_LOG_PROBS = 8
MAX_MODEL_LEN = 1024
with vllm_runner( NUM_LOG_PROBS = 8
base_model,
max_model_len=MAX_MODEL_LEN, with vllm_runner(
tensor_parallel_size=tensor_parallel_size, base_model,
enforce_eager=enforce_eager, max_model_len=MAX_MODEL_LEN,
kv_cache_dtype="auto", tensor_parallel_size=tensor_parallel_size,
disable_async_output_proc=disable_async_output_proc, enforce_eager=enforce_eager,
) as vllm_model: kv_cache_dtype="auto",
baseline_outputs = vllm_model.generate_greedy_logprobs( disable_async_output_proc=disable_async_output_proc,
example_prompts, max_tokens, NUM_LOG_PROBS) ) as vllm_model:
baseline_outputs = vllm_model.generate_greedy_logprobs(
with vllm_runner( example_prompts, max_tokens, NUM_LOG_PROBS)
test_model,
max_model_len=MAX_MODEL_LEN, with vllm_runner(
tensor_parallel_size=tensor_parallel_size, test_model,
enforce_eager=enforce_eager, max_model_len=MAX_MODEL_LEN,
kv_cache_dtype=kv_cache_dtype, tensor_parallel_size=tensor_parallel_size,
disable_async_output_proc=disable_async_output_proc, enforce_eager=enforce_eager,
) as vllm_model: kv_cache_dtype=kv_cache_dtype,
test_outputs = vllm_model.generate_greedy_logprobs( disable_async_output_proc=disable_async_output_proc,
example_prompts, max_tokens, NUM_LOG_PROBS) ) as vllm_model:
test_outputs = vllm_model.generate_greedy_logprobs(
check_logprobs_close( example_prompts, max_tokens, NUM_LOG_PROBS)
outputs_0_lst=baseline_outputs,
outputs_1_lst=test_outputs, check_logprobs_close(
name_0="fp16_kv_cache", outputs_0_lst=baseline_outputs,
name_1="fp8_kv_cache", outputs_1_lst=test_outputs,
) name_0="fp16_kv_cache",
name_1="fp8_kv_cache",
)
@pytest.mark.cpu_model @pytest.mark.cpu_model
...@@ -119,38 +120,41 @@ def test_cpu_models( ...@@ -119,38 +120,41 @@ def test_cpu_models(
test_model: str, test_model: str,
max_tokens: int, max_tokens: int,
disable_async_output_proc: bool, disable_async_output_proc: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
""" """
Only checks log probs match to cover the discrepancy in Only checks log probs match to cover the discrepancy in
numerical sensitive kernels. numerical sensitive kernels.
""" """
with monkeypatch.context() as m:
MAX_MODEL_LEN = 1024 m.setenv("TOKENIZERS_PARALLELISM", 'true')
NUM_LOG_PROBS = 8
MAX_MODEL_LEN = 1024
with vllm_runner( NUM_LOG_PROBS = 8
base_model,
max_model_len=MAX_MODEL_LEN, with vllm_runner(
dtype="bfloat16", base_model,
kv_cache_dtype="auto", max_model_len=MAX_MODEL_LEN,
disable_async_output_proc=disable_async_output_proc, dtype="bfloat16",
) as vllm_model: kv_cache_dtype="auto",
baseline_outputs = vllm_model.generate_greedy_logprobs( disable_async_output_proc=disable_async_output_proc,
example_prompts, max_tokens, NUM_LOG_PROBS) ) as vllm_model:
baseline_outputs = vllm_model.generate_greedy_logprobs(
with vllm_runner( example_prompts, max_tokens, NUM_LOG_PROBS)
test_model,
max_model_len=MAX_MODEL_LEN, with vllm_runner(
dtype="bfloat16", test_model,
kv_cache_dtype=kv_cache_dtype, max_model_len=MAX_MODEL_LEN,
disable_async_output_proc=disable_async_output_proc, dtype="bfloat16",
) as vllm_model: kv_cache_dtype=kv_cache_dtype,
test_outputs = vllm_model.generate_greedy_logprobs( disable_async_output_proc=disable_async_output_proc,
example_prompts, max_tokens, NUM_LOG_PROBS) ) as vllm_model:
test_outputs = vllm_model.generate_greedy_logprobs(
check_logprobs_close( example_prompts, max_tokens, NUM_LOG_PROBS)
outputs_0_lst=baseline_outputs,
outputs_1_lst=test_outputs, check_logprobs_close(
name_0="bf16_kv_cache", outputs_0_lst=baseline_outputs,
name_1="fp8_kv_cache", outputs_1_lst=test_outputs,
) name_0="bf16_kv_cache",
name_1="fp8_kv_cache",
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import importlib.util import importlib.util
import math import math
...@@ -11,6 +12,7 @@ from scipy.spatial.distance import cosine ...@@ -11,6 +12,7 @@ from scipy.spatial.distance import cosine
import vllm import vllm
import vllm.config import vllm.config
from vllm.utils import STR_BACKEND_ENV_VAR
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
...@@ -29,36 +31,34 @@ def _arr(arr): ...@@ -29,36 +31,34 @@ def _arr(arr):
return array("i", arr) return array("i", arr)
def test_find_array(monkeypatch): def test_find_array(monkeypatch: pytest.MonkeyPatch):
# GritLM embedding implementation is only supported by XFormers backend. # GritLM embedding implementation is only supported by XFormers backend.
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
from vllm.model_executor.models.gritlm import GritLMPooler from vllm.model_executor.models.gritlm import GritLMPooler
# Create an LLM object to get the model config. # Create an LLM object to get the model config.
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN) llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
pooler = GritLMPooler(model_config=llm.llm_engine.model_config) pooler = GritLMPooler(model_config=llm.llm_engine.model_config)
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server_embedding(): def server_embedding():
# GritLM embedding implementation is only supported by XFormers backend. # GritLM embedding implementation is only supported by XFormers backend.
with pytest.MonkeyPatch.context() as mp: args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -69,9 +69,12 @@ def server_generate(): ...@@ -69,9 +69,12 @@ def server_generate():
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def client_embedding(server_embedding: RemoteOpenAIServer): async def client_embedding(monkeypatch: pytest.MonkeyPatch,
async with server_embedding.get_async_client() as async_client: server_embedding: RemoteOpenAIServer):
yield async_client with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
async with server_embedding.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture @pytest_asyncio.fixture
...@@ -80,14 +83,20 @@ async def client_generate(server_generate: RemoteOpenAIServer): ...@@ -80,14 +83,20 @@ async def client_generate(server_generate: RemoteOpenAIServer):
yield async_client yield async_client
def run_llm_encode(llm: vllm.LLM, queries: list[str], def run_llm_encode(
instruction: str) -> list[float]: llm: vllm.LLM,
queries: list[str],
instruction: str,
) -> list[float]:
outputs = llm.encode([instruction + q for q in queries], ) outputs = llm.encode([instruction + q for q in queries], )
return [output.outputs.embedding for output in outputs] return [output.outputs.embedding for output in outputs]
async def run_client_embeddings(client: vllm.LLM, queries: list[str], async def run_client_embeddings(
instruction: str) -> list[float]: client: vllm.LLM,
queries: list[str],
instruction: str,
) -> list[float]:
outputs = await client.embeddings.create( outputs = await client.embeddings.create(
model=MODEL_NAME, model=MODEL_NAME,
input=[instruction + q for q in queries], input=[instruction + q for q in queries],
...@@ -106,7 +115,7 @@ def get_test_data(): ...@@ -106,7 +115,7 @@ def get_test_data():
README.md in https://github.com/ContextualAI/gritlm README.md in https://github.com/ContextualAI/gritlm
""" """
q_instruction = gritlm_instruction( q_instruction = gritlm_instruction(
"Given a scientific paper title, retrieve the paper's abstract") "Given a scientific paper title, retrieve the paper's abstract", )
queries = [ queries = [
"Bitcoin: A Peer-to-Peer Electronic Cash System", "Bitcoin: A Peer-to-Peer Electronic Cash System",
"Generative Representational Instruction Tuning", "Generative Representational Instruction Tuning",
...@@ -136,31 +145,32 @@ def validate_embed_output(q_rep: list[float], d_rep: list[float]): ...@@ -136,31 +145,32 @@ def validate_embed_output(q_rep: list[float], d_rep: list[float]):
assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001) assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001)
def test_gritlm_offline_embedding(monkeypatch): def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch):
# GritLM embedding implementation is only supported by XFormers backend. # GritLM embedding implementation is only supported by XFormers backend.
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
queries, q_instruction, documents, d_instruction = get_test_data() queries, q_instruction, documents, d_instruction = get_test_data()
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN) llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
d_rep = run_llm_encode( d_rep = run_llm_encode(
llm, llm,
documents, documents,
d_instruction, d_instruction,
) )
q_rep = run_llm_encode( q_rep = run_llm_encode(
llm, llm,
queries, queries,
q_instruction, q_instruction,
) )
validate_embed_output(q_rep, d_rep) validate_embed_output(q_rep, d_rep)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gritlm_api_server_embedding( async def test_gritlm_api_server_embedding(
client_embedding: openai.AsyncOpenAI): client_embedding: openai.AsyncOpenAI, ):
queries, q_instruction, documents, d_instruction = get_test_data() queries, q_instruction, documents, d_instruction = get_test_data()
d_rep = await run_client_embeddings( d_rep = await run_client_embeddings(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
...@@ -11,76 +9,92 @@ from ..utils import fork_new_process_for_each_test ...@@ -11,76 +9,92 @@ from ..utils import fork_new_process_for_each_test
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_plugin(dummy_opt_path, monkeypatch): def test_plugin(
monkeypatch: pytest.MonkeyPatch,
dummy_opt_path: str,
):
# V1 shuts down rather than raising an error here. # V1 shuts down rather than raising an error here.
monkeypatch.setenv("VLLM_USE_V1", "0") with monkeypatch.context() as m:
os.environ["VLLM_PLUGINS"] = "" m.setenv("VLLM_USE_V1", "0")
with pytest.raises(Exception) as excinfo: m.setenv("VLLM_PLUGINS", "")
LLM(model=dummy_opt_path, load_format="dummy")
error_msg = "has no vLLM implementation and " \
"the Transformers implementation is not compatible with vLLM"
assert (error_msg in str(excinfo.value))
with pytest.raises(Exception) as excinfo:
LLM(model=dummy_opt_path, load_format="dummy")
error_msg = "has no vLLM implementation and the Transformers implementation is not compatible with vLLM" # noqa: E501
assert (error_msg in str(excinfo.value))
@fork_new_process_for_each_test
def test_oot_registration_text_generation(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=dummy_opt_path, load_format="dummy")
first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params)
for output in outputs: @fork_new_process_for_each_test
generated_text = output.outputs[0].text def test_oot_registration_text_generation(
# make sure only the first token is generated monkeypatch: pytest.MonkeyPatch,
rest = generated_text.replace(first_token, "") dummy_opt_path: str,
assert rest == "" ):
with monkeypatch.context() as m:
m.setenv("VLLM_PLUGINS", "register_dummy_model")
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=dummy_opt_path, load_format="dummy")
first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
# make sure only the first token is generated
rest = generated_text.replace(first_token, "")
assert rest == ""
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_oot_registration_embedding(dummy_gemma2_embedding_path): def test_oot_registration_embedding(
os.environ["VLLM_PLUGINS"] = "register_dummy_model" monkeypatch: pytest.MonkeyPatch,
prompts = ["Hello, my name is", "The text does not matter"] dummy_gemma2_embedding_path: str,
llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy") ):
outputs = llm.embed(prompts) with monkeypatch.context() as m:
m.setenv("VLLM_PLUGINS", "register_dummy_model")
prompts = ["Hello, my name is", "The text does not matter"]
llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
outputs = llm.embed(prompts)
for output in outputs: for output in outputs:
assert all(v == 0 for v in output.outputs.embedding) assert all(v == 0 for v in output.outputs.embedding)
image = ImageAsset("cherry_blossom").pil_image.convert("RGB") image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_oot_registration_multimodal(dummy_llava_path, monkeypatch): def test_oot_registration_multimodal(
os.environ["VLLM_PLUGINS"] = "register_dummy_model" monkeypatch: pytest.MonkeyPatch,
prompts = [{ dummy_llava_path: str,
"prompt": "What's in the image?<image>", ):
"multi_modal_data": { with monkeypatch.context() as m:
"image": image m.setenv("VLLM_PLUGINS", "register_dummy_model")
}, prompts = [{
}, { "prompt": "What's in the image?<image>",
"prompt": "Describe the image<image>", "multi_modal_data": {
"multi_modal_data": { "image": image
"image": image },
}, }, {
}] "prompt": "Describe the image<image>",
"multi_modal_data": {
sampling_params = SamplingParams(temperature=0) "image": image
llm = LLM(model=dummy_llava_path, },
load_format="dummy", }]
max_num_seqs=1,
trust_remote_code=True, sampling_params = SamplingParams(temperature=0)
gpu_memory_utilization=0.98, llm = LLM(model=dummy_llava_path,
max_model_len=4096, load_format="dummy",
enforce_eager=True, max_num_seqs=1,
limit_mm_per_prompt={"image": 1}) trust_remote_code=True,
first_token = llm.get_tokenizer().decode(0) gpu_memory_utilization=0.98,
outputs = llm.generate(prompts, sampling_params) max_model_len=4096,
enforce_eager=True,
for output in outputs: limit_mm_per_prompt={"image": 1})
generated_text = output.outputs[0].text first_token = llm.get_tokenizer().decode(0)
# make sure only the first token is generated outputs = llm.generate(prompts, sampling_params)
rest = generated_text.replace(first_token, "")
assert rest == "" for output in outputs:
generated_text = output.outputs[0].text
# make sure only the first token is generated
rest = generated_text.replace(first_token, "")
assert rest == ""
...@@ -235,25 +235,28 @@ async def test_bad_request(tmp_socket): ...@@ -235,25 +235,28 @@ async def test_bad_request(tmp_socket):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch): async def test_mp_crash_detection(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") parser = FlexibleArgumentParser(
parser = make_arg_parser(parser) description="vLLM's remote OpenAI server.")
args = parser.parse_args([]) parser = make_arg_parser(parser)
args = parser.parse_args([])
# When LLMEngine is loaded, it will crash. # When LLMEngine is loaded, it will crash.
def mock_init(): def mock_init():
raise ValueError raise ValueError
monkeypatch.setattr(LLMEngine, "__init__", mock_init) m.setattr(LLMEngine, "__init__", mock_init)
start = time.perf_counter() start = time.perf_counter()
async with build_async_engine_client(args): async with build_async_engine_client(args):
pass pass
end = time.perf_counter() end = time.perf_counter()
assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " assert end - start < 60, (
"if there is an error in the startup.") "Expected vLLM to gracefully shutdown in <60s "
"if there is an error in the startup.")
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -5,7 +5,7 @@ from typing import Optional ...@@ -5,7 +5,7 @@ from typing import Optional
import pytest import pytest
from tests.kernels.utils import override_backend_env_variable from vllm.utils import STR_BACKEND_ENV_VAR
from ..models.utils import check_logprobs_close from ..models.utils import check_logprobs_close
from ..utils import (completions_with_server_args, get_client_text_generations, from ..utils import (completions_with_server_args, get_client_text_generations,
...@@ -52,7 +52,7 @@ async def test_multi_step( ...@@ -52,7 +52,7 @@ async def test_multi_step(
num_logprobs: Optional[int], num_logprobs: Optional[int],
attention_backend: str, attention_backend: str,
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol """Test vLLM engine with multi-step scheduling in an OpenAI-protocol
client/server environment. client/server environment.
...@@ -82,67 +82,70 @@ async def test_multi_step( ...@@ -82,67 +82,70 @@ async def test_multi_step(
pytest.skip("Multi-step with Chunked-Prefill only supports" pytest.skip("Multi-step with Chunked-Prefill only supports"
"PP=1 and FLASH_ATTN backend") "PP=1 and FLASH_ATTN backend")
override_backend_env_variable(monkeypatch, attention_backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
prompts = example_prompts
if len(prompts) < num_prompts: prompts = example_prompts
prompts = prompts * ((num_prompts // len(prompts)) + 1) if len(prompts) < num_prompts:
prompts = prompts[:num_prompts] prompts = prompts * ((num_prompts // len(prompts)) + 1)
assert len(prompts) == num_prompts prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts
server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
ms_server_args = DEFAULT_SERVER_ARGS + \ server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
["--num-scheduler-steps", f"{num_scheduler_steps}"] ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]
if not is_async:
ms_server_args += ["--disable-async-output-proc"] if not is_async:
ms_server_args += ["--disable-async-output-proc"]
if eager_mode:
ms_server_args.append("--enforce-eager") if eager_mode:
ms_server_args.append("--enforce-eager")
if enable_chunked_prefill:
ms_server_args.append("--enable-chunked-prefill") if enable_chunked_prefill:
ms_server_args.append("--enable-chunked-prefill")
distributed_args = [
"--tensor-parallel-size", distributed_args = [
str(tp_size), "--tensor-parallel-size",
"--pipeline-parallel-size", str(tp_size),
str(pp_size), "--pipeline-parallel-size",
] str(pp_size),
]
# Spin up client/server & issue completion API requests.
# Default `max_wait_seconds` is 240 but it was empirically # Spin up client/server & issue completion API requests.
# raised 5x to 1200 *just for this test* due to # Default `max_wait_seconds` is 240 but it was empirically
# observed timeouts in GHA CI # raised 5x to 1200 *just for this test* due to
ref_completions = await completions_with_server_args( # observed timeouts in GHA CI
prompts, ref_completions = await completions_with_server_args(
model, prompts,
server_args + distributed_args, model,
num_logprobs, server_args + distributed_args,
max_wait_seconds=5 * 240) num_logprobs,
test_completions = await completions_with_server_args( max_wait_seconds=5 * 240)
prompts, test_completions = await completions_with_server_args(
model, prompts,
ms_server_args + distributed_args, model,
num_logprobs, ms_server_args + distributed_args,
max_wait_seconds=5 * 240) num_logprobs,
max_wait_seconds=5 * 240)
# Assert multi-step scheduling produces identical tokens
# to single-step scheduling. # Assert multi-step scheduling produces identical tokens
ref_generations = get_client_text_generations(ref_completions) # to single-step scheduling.
test_generations = get_client_text_generations(test_completions) ref_generations = get_client_text_generations(ref_completions)
assert ref_generations == test_generations test_generations = get_client_text_generations(test_completions)
assert ref_generations == test_generations
# Assert multi-step scheduling produces nearly-identical logprobs
# to single-step scheduling. # Assert multi-step scheduling produces nearly-identical logprobs
ref_text_logprobs = get_client_text_logprob_generations(ref_completions) # to single-step scheduling.
test_text_logprobs = get_client_text_logprob_generations(test_completions) ref_text_logprobs = get_client_text_logprob_generations(
check_logprobs_close( ref_completions)
outputs_0_lst=ref_text_logprobs, test_text_logprobs = get_client_text_logprob_generations(
outputs_1_lst=test_text_logprobs, test_completions)
name_0="hf", check_logprobs_close(
name_1="vllm", outputs_0_lst=ref_text_logprobs,
) outputs_1_lst=test_text_logprobs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize(("tp_size, pp_size"), [ @pytest.mark.parametrize(("tp_size, pp_size"), [
...@@ -152,7 +155,7 @@ async def test_multi_step( ...@@ -152,7 +155,7 @@ async def test_multi_step(
async def test_multi_step_pp_smoke( async def test_multi_step_pp_smoke(
tp_size: int, tp_size: int,
pp_size: int, pp_size: int,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
""" """
Smoke test for the vLLM engine with multi-step scheduling in an Smoke test for the vLLM engine with multi-step scheduling in an
...@@ -174,54 +177,55 @@ async def test_multi_step_pp_smoke( ...@@ -174,54 +177,55 @@ async def test_multi_step_pp_smoke(
attention_backend = "FLASH_ATTN" attention_backend = "FLASH_ATTN"
max_num_seqs = 3 max_num_seqs = 3
override_backend_env_variable(monkeypatch, attention_backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
# Prompt from the ShareGPT dataset
prompts = [ # Prompt from the ShareGPT dataset
"in the jtbd context whats a push?", # codespell:ignore prompts = [
"in the jtbd context whats a push?", # codespell:ignore "in the jtbd context whats a push?", # codespell:ignore
"in the jtbd context whats a push?", # codespell:ignore "in the jtbd context whats a push?", # codespell:ignore
"in the jtbd context whats a push?", # codespell:ignore "in the jtbd context whats a push?", # codespell:ignore
] "in the jtbd context whats a push?", # codespell:ignore
# Use varying max_tokens to introduce scheduling randomness. ]
max_tokens = [10 * i for i in range(1, len(prompts) + 1)] # Use varying max_tokens to introduce scheduling randomness.
assert len(prompts) == len(max_tokens) max_tokens = [10 * i for i in range(1, len(prompts) + 1)]
assert len(prompts) == len(max_tokens)
test_args = [
"--tensor-parallel-size", test_args = [
str(tp_size), "--pipeline-parallel-size", "--tensor-parallel-size",
str(pp_size), "--max-num-seqs", str(tp_size), "--pipeline-parallel-size",
str(max_num_seqs) str(pp_size), "--max-num-seqs",
] str(max_num_seqs)
]
server_args = DEFAULT_SERVER_ARGS + test_args
ms_server_args = DEFAULT_SERVER_ARGS + \ server_args = DEFAULT_SERVER_ARGS + test_args
["--num-scheduler-steps", f"{num_scheduler_steps}"] + \ ms_server_args = DEFAULT_SERVER_ARGS + \
test_args ["--num-scheduler-steps", f"{num_scheduler_steps}"] + \
test_args
# Spin up client/server & issue completion API requests.
# Default `max_wait_seconds` is 240 but it was empirically # Spin up client/server & issue completion API requests.
# raised 5x to 1200 *just for this test* due to # Default `max_wait_seconds` is 240 but it was empirically
# observed timeouts in GHA CI # raised 5x to 1200 *just for this test* due to
ref_completions = await completions_with_server_args( # observed timeouts in GHA CI
prompts=prompts, ref_completions = await completions_with_server_args(
model_name=model, prompts=prompts,
server_cli_args=server_args, model_name=model,
num_logprobs=None, server_cli_args=server_args,
max_wait_seconds=5 * 240, num_logprobs=None,
max_tokens=max_tokens) max_wait_seconds=5 * 240,
max_tokens=max_tokens)
test_completions = await completions_with_server_args(
prompts=prompts, test_completions = await completions_with_server_args(
model_name=model, prompts=prompts,
server_cli_args=ms_server_args, model_name=model,
num_logprobs=None, server_cli_args=ms_server_args,
max_wait_seconds=5 * 240, num_logprobs=None,
max_tokens=max_tokens) max_wait_seconds=5 * 240,
max_tokens=max_tokens)
# Assert multi-step scheduling produces identical tokens
# to single-step scheduling. # Assert multi-step scheduling produces identical tokens
ref_generations = get_client_text_generations(ref_completions) # to single-step scheduling.
test_generations = get_client_text_generations(test_completions) ref_generations = get_client_text_generations(ref_completions)
test_generations = get_client_text_generations(test_completions)
assert ref_generations == test_generations
assert ref_generations == test_generations
...@@ -7,7 +7,7 @@ from typing import Optional ...@@ -7,7 +7,7 @@ from typing import Optional
import pytest import pytest
from tests.kernels.utils import override_backend_env_variable from vllm.utils import STR_BACKEND_ENV_VAR
from ..models.utils import check_logprobs_close, check_outputs_equal from ..models.utils import check_logprobs_close, check_outputs_equal
...@@ -42,7 +42,7 @@ def test_multi_step_llm( ...@@ -42,7 +42,7 @@ def test_multi_step_llm(
num_prompts: int, num_prompts: int,
num_logprobs: Optional[int], num_logprobs: Optional[int],
attention_backend: str, attention_backend: str,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine. """Test vLLM engine with multi-step scheduling via sync LLM Engine.
...@@ -70,48 +70,49 @@ def test_multi_step_llm( ...@@ -70,48 +70,49 @@ def test_multi_step_llm(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned. completions endpoint; `None` -> 1 logprob returned.
""" """
override_backend_env_variable(monkeypatch, attention_backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
prompts = example_prompts
if len(prompts) < num_prompts: prompts = example_prompts
prompts = prompts * ((num_prompts // len(prompts)) + 1) if len(prompts) < num_prompts:
prompts = prompts[:num_prompts] prompts = prompts * ((num_prompts // len(prompts)) + 1)
assert len(prompts) == num_prompts prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts
with vllm_runner(
model, with vllm_runner(
dtype=dtype, model,
enforce_eager=enforce_eager, dtype=dtype,
gpu_memory_utilization=0.7, enforce_eager=enforce_eager,
tensor_parallel_size=tp_size, gpu_memory_utilization=0.7,
enable_chunked_prefill=enable_chunked_prefill, tensor_parallel_size=tp_size,
num_scheduler_steps=num_scheduler_steps, enable_chunked_prefill=enable_chunked_prefill,
) as vllm_model: num_scheduler_steps=num_scheduler_steps,
vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens) ) as vllm_model:
if num_logprobs is None else vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)
vllm_model.generate_greedy_logprobs( if num_logprobs is None else
prompts, max_tokens, num_logprobs)) vllm_model.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs))
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = (hf_model.generate_greedy(prompts, max_tokens) with hf_runner(model, dtype=dtype) as hf_model:
if num_logprobs is None else hf_outputs = (hf_model.generate_greedy(prompts, max_tokens)
hf_model.generate_greedy_logprobs_limit( if num_logprobs is None else
prompts, max_tokens, num_logprobs)) hf_model.generate_greedy_logprobs_limit(
prompts, max_tokens, num_logprobs))
if num_logprobs is None:
check_outputs_equal( if num_logprobs is None:
outputs_0_lst=hf_outputs, check_outputs_equal(
outputs_1_lst=vllm_outputs, outputs_0_lst=hf_outputs,
name_0="hf", outputs_1_lst=vllm_outputs,
name_1="vllm", name_0="hf",
) name_1="vllm",
else: )
check_logprobs_close( else:
outputs_0_lst=hf_outputs, check_logprobs_close(
outputs_1_lst=vllm_outputs, outputs_0_lst=hf_outputs,
name_0="hf", outputs_1_lst=vllm_outputs,
name_1="vllm", name_0="hf",
) name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
...@@ -136,7 +137,7 @@ def test_multi_step_llm_w_prompt_logprobs( ...@@ -136,7 +137,7 @@ def test_multi_step_llm_w_prompt_logprobs(
num_logprobs: Optional[int], num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int], num_prompt_logprobs: Optional[int],
attention_backend: str, attention_backend: str,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine. """Test prompt logprobs with multi-step scheduling via sync LLM Engine.
...@@ -166,47 +167,48 @@ def test_multi_step_llm_w_prompt_logprobs( ...@@ -166,47 +167,48 @@ def test_multi_step_llm_w_prompt_logprobs(
note that this argument is not supported by the note that this argument is not supported by the
OpenAI completions endpoint. OpenAI completions endpoint.
""" """
override_backend_env_variable(monkeypatch, attention_backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
prompts = example_prompts
if len(prompts) < num_prompts: prompts = example_prompts
prompts = prompts * ((num_prompts // len(prompts)) + 1) if len(prompts) < num_prompts:
prompts = prompts[:num_prompts] prompts = prompts * ((num_prompts // len(prompts)) + 1)
assert len(prompts) == num_prompts prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts
with vllm_runner(
model, with vllm_runner(
dtype=dtype, model,
enforce_eager=enforce_eager, dtype=dtype,
gpu_memory_utilization=0.7, enforce_eager=enforce_eager,
tensor_parallel_size=tp_size, gpu_memory_utilization=0.7,
num_scheduler_steps=num_scheduler_steps, tensor_parallel_size=tp_size,
) as vllm_model: num_scheduler_steps=num_scheduler_steps,
vllm_outputs = vllm_model.generate_greedy_logprobs( ) as vllm_model:
prompts, vllm_outputs = vllm_model.generate_greedy_logprobs(
max_tokens, prompts,
num_logprobs, max_tokens,
num_prompt_logprobs=num_prompt_logprobs) num_logprobs,
num_prompt_logprobs=num_prompt_logprobs)
with vllm_runner(
model, with vllm_runner(
dtype=dtype, model,
enforce_eager=enforce_eager, dtype=dtype,
gpu_memory_utilization=0.7, enforce_eager=enforce_eager,
tensor_parallel_size=tp_size, gpu_memory_utilization=0.7,
) as vllm_model: tensor_parallel_size=tp_size,
single_step_vllm_outputs = vllm_model.generate_greedy_logprobs( ) as vllm_model:
prompts, single_step_vllm_outputs = vllm_model.generate_greedy_logprobs(
max_tokens, prompts,
num_logprobs, max_tokens,
num_prompt_logprobs=num_prompt_logprobs) num_logprobs,
num_prompt_logprobs=num_prompt_logprobs)
check_logprobs_close(
outputs_0_lst=single_step_vllm_outputs, check_logprobs_close(
outputs_1_lst=vllm_outputs, outputs_0_lst=single_step_vllm_outputs,
name_0="hf", outputs_1_lst=vllm_outputs,
name_1="vllm", name_0="hf",
) name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
...@@ -230,7 +232,7 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( ...@@ -230,7 +232,7 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
num_prompts: int, num_prompts: int,
num_logprobs: Optional[int], num_logprobs: Optional[int],
attention_backend: str, attention_backend: str,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC. """Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
...@@ -293,77 +295,78 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( ...@@ -293,77 +295,78 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
# #
# The Incorrect scheduling behavior - if it occurs - will cause an exception # The Incorrect scheduling behavior - if it occurs - will cause an exception
# in the model runner resulting from `do_sample=False`. # in the model runner resulting from `do_sample=False`.
override_backend_env_variable(monkeypatch, attention_backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
assert len(example_prompts) >= 2
challenge_prompts = copy.deepcopy(example_prompts) assert len(example_prompts) >= 2
challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' challenge_prompts = copy.deepcopy(example_prompts)
'inference and serving engine for LLMs.\n' challenge_prompts[0] = (
) # 24 tok 'vLLM is a high-throughput and memory-efficient '
challenge_prompts[1] = ( 'inference and serving engine for LLMs.\n') # 24 tok
'Briefly describe the major milestones in the ' challenge_prompts[1] = (
'development of artificial intelligence from 1950 to 2020.\n' 'Briefly describe the major milestones in the '
) # 30 tok 'development of artificial intelligence from 1950 to 2020.\n'
) # 30 tok
# If necessary, adjust the length of `challenge_prompts` to match
# `num_prompts` # If necessary, adjust the length of `challenge_prompts` to match
if len(challenge_prompts) < num_prompts: # `num_prompts`
challenge_prompts = (challenge_prompts * if len(challenge_prompts) < num_prompts:
((num_prompts // len(challenge_prompts)) + 1)) challenge_prompts = (challenge_prompts *
challenge_prompts = challenge_prompts[:num_prompts] ((num_prompts // len(challenge_prompts)) + 1))
assert len(challenge_prompts) == num_prompts challenge_prompts = challenge_prompts[:num_prompts]
assert len(challenge_prompts) == num_prompts
# Single-step scheduler baseline
with vllm_runner( # Single-step scheduler baseline
model, with vllm_runner(
dtype=dtype, model,
enforce_eager=enforce_eager, dtype=dtype,
gpu_memory_utilization=0.7, enforce_eager=enforce_eager,
tensor_parallel_size=tp_size, gpu_memory_utilization=0.7,
num_scheduler_steps=num_scheduler_steps, tensor_parallel_size=tp_size,
max_model_len=48, num_scheduler_steps=num_scheduler_steps,
max_num_batched_tokens=48, max_model_len=48,
max_num_seqs=4, max_num_batched_tokens=48,
block_size=16, max_num_seqs=4,
) as vllm_model: block_size=16,
outputs_baseline = (vllm_model.generate_greedy( ) as vllm_model:
challenge_prompts, max_tokens) if num_logprobs is None else outputs_baseline = (
vllm_model.generate_greedy_logprobs( vllm_model.generate_greedy(challenge_prompts, max_tokens) if
challenge_prompts, max_tokens, num_logprobs)) num_logprobs is None else vllm_model.generate_greedy_logprobs(
challenge_prompts, max_tokens, num_logprobs))
# multi-step+"single-step chunked prefill"+APC
with vllm_runner( # multi-step+"single-step chunked prefill"+APC
model, with vllm_runner(
dtype=dtype, model,
enforce_eager=enforce_eager, dtype=dtype,
gpu_memory_utilization=0.7, enforce_eager=enforce_eager,
tensor_parallel_size=tp_size, gpu_memory_utilization=0.7,
enable_chunked_prefill=True, tensor_parallel_size=tp_size,
enable_prefix_caching=True, enable_chunked_prefill=True,
num_scheduler_steps=num_scheduler_steps, enable_prefix_caching=True,
max_model_len=48, num_scheduler_steps=num_scheduler_steps,
max_num_batched_tokens=48, max_model_len=48,
max_num_seqs=4, max_num_batched_tokens=48,
block_size=16, max_num_seqs=4,
) as vllm_model: block_size=16,
outputs_w_features = (vllm_model.generate_greedy( ) as vllm_model:
challenge_prompts, max_tokens) if num_logprobs is None else outputs_w_features = (
vllm_model.generate_greedy_logprobs( vllm_model.generate_greedy(challenge_prompts, max_tokens) if
challenge_prompts, max_tokens, num_logprobs)) num_logprobs is None else vllm_model.generate_greedy_logprobs(
challenge_prompts, max_tokens, num_logprobs))
if num_logprobs is None:
# No-logprobs test if num_logprobs is None:
check_outputs_equal( # No-logprobs test
outputs_0_lst=outputs_baseline, check_outputs_equal(
outputs_1_lst=outputs_w_features, outputs_0_lst=outputs_baseline,
name_0="multi-step", outputs_1_lst=outputs_w_features,
name_1="multi-step+features", name_0="multi-step",
) name_1="multi-step+features",
else: )
# Yes-logprobs test else:
check_logprobs_close( # Yes-logprobs test
outputs_0_lst=outputs_baseline, check_logprobs_close(
outputs_1_lst=outputs_w_features, outputs_0_lst=outputs_baseline,
name_0="multi-step", outputs_1_lst=outputs_w_features,
name_1="multi-step+features", name_0="multi-step",
) name_1="multi-step+features",
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import neuronxcc.nki.language as nl import neuronxcc.nki.language as nl
import pytest import pytest
...@@ -99,6 +98,7 @@ def ref_block_tables_transform( ...@@ -99,6 +98,7 @@ def ref_block_tables_transform(
) )
@torch.inference_mode() @torch.inference_mode()
def test_load_and_transform_block_tables( def test_load_and_transform_block_tables(
monkeypatch: pytest.MonkeyPatch,
num_tiles, num_tiles,
num_blocks_per_tile, num_blocks_per_tile,
q_head_per_kv_head, q_head_per_kv_head,
...@@ -108,46 +108,46 @@ def test_load_and_transform_block_tables( ...@@ -108,46 +108,46 @@ def test_load_and_transform_block_tables(
device = xm.xla_device() device = xm.xla_device()
compiler_flags = [ compiler_flags_str = " ".join([
"-O1", "-O1",
"--retry_failed_compilation", "--retry_failed_compilation",
] ])
compiler_flags_str = " ".join(compiler_flags) with monkeypatch.context() as m:
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str m.setenv("NEURON_CC_FLAGS", compiler_flags_str)
torch.manual_seed(10000) torch.manual_seed(10000)
torch.set_printoptions(sci_mode=False) torch.set_printoptions(sci_mode=False)
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
B_P_SIZE = 128 B_P_SIZE = 128
if num_blocks_per_tile < B_P_SIZE: if num_blocks_per_tile < B_P_SIZE:
assert B_P_SIZE % num_blocks_per_tile == 0 assert B_P_SIZE % num_blocks_per_tile == 0
block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile
else: else:
block_size_tiling_factor = 1 block_size_tiling_factor = 1
max_num_blocks = 100000 max_num_blocks = 100000
block_tables = torch.randint( block_tables = torch.randint(
0, 0,
max_num_blocks, max_num_blocks,
(num_tiles * num_blocks_per_tile, ), (num_tiles * num_blocks_per_tile, ),
dtype=torch.int32, dtype=torch.int32,
) )
nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1]( nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1](
block_tables.to(device=device), block_tables.to(device=device),
num_tiles, num_tiles,
num_blocks_per_tile, num_blocks_per_tile,
q_head_per_kv_head, q_head_per_kv_head,
head_id, head_id,
block_size_tiling_factor, block_size_tiling_factor,
).cpu() ).cpu()
ref_out = ref_block_tables_transform( ref_out = ref_block_tables_transform(
block_tables, block_tables,
num_tiles, num_tiles,
num_blocks_per_tile, num_blocks_per_tile,
q_head_per_kv_head, q_head_per_kv_head,
head_id, head_id,
block_size_tiling_factor, block_size_tiling_factor,
) )
assert (nki_out.shape == ref_out.shape assert (nki_out.shape == ref_out.shape
), f"{nki_out.shape=} != {ref_out.shape=}" ), f"{nki_out.shape=} != {ref_out.shape=}"
assert torch.all(nki_out == ref_out) assert torch.all(nki_out == ref_out)
...@@ -320,6 +320,7 @@ def get_active_block_tables(block_tables, query_lens, seq_lens, block_size, ...@@ -320,6 +320,7 @@ def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
]) ])
@torch.inference_mode() @torch.inference_mode()
def test_contexted_kv_attention( def test_contexted_kv_attention(
monkeypatch: pytest.MonkeyPatch,
prefill_batch_size: int, prefill_batch_size: int,
decode_batch_size: int, decode_batch_size: int,
num_heads: int, num_heads: int,
...@@ -329,7 +330,6 @@ def test_contexted_kv_attention( ...@@ -329,7 +330,6 @@ def test_contexted_kv_attention(
large_tile_size, large_tile_size,
mixed_precision: bool, mixed_precision: bool,
) -> None: ) -> None:
import os
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
...@@ -340,174 +340,178 @@ def test_contexted_kv_attention( ...@@ -340,174 +340,178 @@ def test_contexted_kv_attention(
device = xm.xla_device() device = xm.xla_device()
compiler_flags = [ compiler_flags_str = " ".join([
"-O1", "-O1",
"--retry_failed_compilation", "--retry_failed_compilation",
] ])
compiler_flags_str = " ".join(compiler_flags) with monkeypatch.context() as m:
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str m.setenv("NEURON_CC_FLAGS", compiler_flags_str)
torch.manual_seed(0) torch.manual_seed(0)
torch.set_printoptions(sci_mode=False) torch.set_printoptions(sci_mode=False)
torch.set_default_device("cpu") torch.set_default_device("cpu")
dtype = torch.float32 dtype = torch.float32
min_ctx_len = 32 min_ctx_len = 32
max_ctx_len = 1024 max_ctx_len = 1024
min_query_len = 16 min_query_len = 16
max_query_len = 512 max_query_len = 512
num_kv_heads = num_heads // num_queries_per_kv num_kv_heads = num_heads // num_queries_per_kv
( (
query, query,
k_active, k_active,
v_active, v_active,
k_cache, k_cache,
v_cache, v_cache,
block_table, block_table,
key, key,
value, value,
query_lens, query_lens,
seq_lens, seq_lens,
) = sample_inputs( ) = sample_inputs(
prefill_batch_size=prefill_batch_size, prefill_batch_size=prefill_batch_size,
decode_batch_size=decode_batch_size, decode_batch_size=decode_batch_size,
min_query_len=min_query_len, min_query_len=min_query_len,
max_query_len=max_query_len, max_query_len=max_query_len,
min_ctx_len=min_ctx_len, min_ctx_len=min_ctx_len,
max_ctx_len=max_ctx_len, max_ctx_len=max_ctx_len,
block_size=block_size, block_size=block_size,
num_heads=num_heads, num_heads=num_heads,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
) )
output_ref = ref_context_attention( output_ref = ref_context_attention(
query, query,
key, key,
value, value,
query_lens, query_lens,
seq_lens, seq_lens,
head_size, head_size,
num_queries_per_kv, num_queries_per_kv,
return_max_reduce=False, return_max_reduce=False,
) )
# build neuron program # build neuron program
B_P_SIZE = 128 B_P_SIZE = 128
assert (large_tile_size >= B_P_SIZE assert (large_tile_size >= B_P_SIZE
), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}" ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"
def ceil_div(a, b):
return (a + b - 1) // b
def pad_to_multiple(a, b):
return ceil_div(a, b) * b
def pad_to_next_power_of_2(a):
assert a > 0
return 2**int(a - 1).bit_length()
# calculate input shapes
max_num_queries = pad_to_next_power_of_2(sum(query_lens))
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
num_active_blocks = pad_to_multiple(num_active_blocks,
large_tile_size // block_size)
context_kv_len = num_active_blocks * block_size
assert (context_kv_len %
large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"
# pad QKV tensors def ceil_div(a, b):
pad_dims = ( return (a + b - 1) // b
0,
0,
0,
0,
0,
max_num_queries - query.shape[0],
)
query = F.pad(query, pad_dims, "constant", 0)
k = F.pad(k_active, pad_dims, "constant", 0)
v = F.pad(v_active, pad_dims, "constant", 0)
# permute QKV tensors
# query: (1, n_heads, d, seq_q)
# key: (1, n_kv_heads, d, seq_k)
# value: (1, n_kv_heads, seq_v, d)
query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
k_cache = k_cache.permute(0, 2, 1, 3).contiguous()
v_cache = v_cache.permute(0, 2, 1, 3).contiguous()
# transform block table
active_block_table = get_active_block_tables(
block_table.cpu(),
torch.tensor(query_lens).cpu(),
torch.tensor(seq_lens).cpu(),
block_size,
num_active_blocks,
)
# Build attention masks def pad_to_multiple(a, b):
prior_mask, active_mask = ( return ceil_div(a, b) * b
BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens, block_size=block_size)) def pad_to_next_power_of_2(a):
prior_mask_padded = F.pad( assert a > 0
prior_mask, return 2**int(a - 1).bit_length()
(
# calculate input shapes
max_num_queries = pad_to_next_power_of_2(sum(query_lens))
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
num_active_blocks = pad_to_multiple(num_active_blocks,
large_tile_size // block_size)
context_kv_len = num_active_blocks * block_size
assert (
context_kv_len %
large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"
# pad QKV tensors
pad_dims = (
0, 0,
context_kv_len - prior_mask.shape[1],
0, 0,
max_num_queries - prior_mask.shape[0],
),
"constant",
0,
).bool()
active_mask_padded = F.pad(
active_mask,
(
0, 0,
max_num_queries - active_mask.shape[1],
0, 0,
max_num_queries - active_mask.shape[0], 0,
), max_num_queries - query.shape[0],
"constant", )
0, query = F.pad(query, pad_dims, "constant", 0)
).bool() k = F.pad(k_active, pad_dims, "constant", 0)
attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1) v = F.pad(v_active, pad_dims, "constant", 0)
attn_mask = reorder_context_mask(attn_mask, large_tile_size, block_size) # permute QKV tensors
# query: (1, n_heads, d, seq_q)
input_args = ( # key: (1, n_kv_heads, d, seq_k)
query.to(device=device), # value: (1, n_kv_heads, seq_v, d)
k.to(device=device), query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
v.to(device=device), k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
k_cache.to(device=device), v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
v_cache.to(device=device), k_cache = k_cache.permute(0, 2, 1, 3).contiguous()
active_block_table.to(device=device), v_cache = v_cache.permute(0, 2, 1, 3).contiguous()
attn_mask.to(device=device),
) # transform block table
input_kwargs = dict( active_block_table = get_active_block_tables(
n_kv_head=num_kv_heads, block_table.cpu(),
head_size=head_size, torch.tensor(query_lens).cpu(),
mixed_precision=mixed_precision, torch.tensor(seq_lens).cpu(),
LARGE_TILE_SZ=large_tile_size, block_size,
) num_active_blocks,
)
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs) # Build attention masks
prior_mask, active_mask = (
BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens, block_size=block_size))
prior_mask_padded = F.pad(
prior_mask,
(
0,
context_kv_len - prior_mask.shape[1],
0,
max_num_queries - prior_mask.shape[0],
),
"constant",
0,
).bool()
active_mask_padded = F.pad(
active_mask,
(
0,
max_num_queries - active_mask.shape[1],
0,
max_num_queries - active_mask.shape[0],
),
"constant",
0,
).bool()
attn_mask = torch.concat([prior_mask_padded, active_mask_padded],
dim=1)
attn_mask = reorder_context_mask(attn_mask, large_tile_size,
block_size)
input_args = (
query.to(device=device),
k.to(device=device),
v.to(device=device),
k_cache.to(device=device),
v_cache.to(device=device),
active_block_table.to(device=device),
attn_mask.to(device=device),
)
input_kwargs = dict(
n_kv_head=num_kv_heads,
head_size=head_size,
mixed_precision=mixed_precision,
LARGE_TILE_SZ=large_tile_size,
)
num_actual_tokens = sum(query_lens) output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
output_nki = output_nki.cpu().permute(0, 2, 1, 3) num_actual_tokens = sum(query_lens)
output_nki = output_nki[0, :num_actual_tokens, :, :] # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
output_ref_padded = F.pad( output_nki = output_nki.cpu().permute(0, 2, 1, 3)
output_ref, output_nki = output_nki[0, :num_actual_tokens, :, :]
(0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]), output_ref_padded = F.pad(
"constant", output_ref,
0, (0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
) "constant",
output_ref = output_ref_padded.transpose(0, 1)[0, :num_actual_tokens, :, :] 0,
)
output_ref = output_ref_padded.transpose(
0, 1)[0, :num_actual_tokens, :, :]
torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0) torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pytest
import torch import torch
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.utils import STR_INVALID_VAL from vllm.utils import STR_BACKEND_ENV_VAR, STR_INVALID_VAL
def test_platform_plugins(): def test_platform_plugins():
...@@ -25,8 +25,9 @@ def test_platform_plugins(): ...@@ -25,8 +25,9 @@ def test_platform_plugins():
f" is loaded. The first import:\n{_init_trace}") f" is loaded. The first import:\n{_init_trace}")
def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch):
    """Check that `get_attn_backend` returns the backend named "Dummy_Backend".

    The backend env variable is overridden with `STR_INVALID_VAL` inside a
    `monkeypatch.context()` block, so the override only applies while the
    backend is being resolved and checked.
    """
    # ignore the backend env variable if it is set
    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
        backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
        assert backend.get_name() == "Dummy_Backend"
...@@ -22,43 +22,47 @@ class DummyV1Scheduler(V1Scheduler): ...@@ -22,43 +22,47 @@ class DummyV1Scheduler(V1Scheduler):
raise Exception("Exception raised by DummyV1Scheduler") raise Exception("Exception raised by DummyV1Scheduler")
def test_scheduler_plugins_v0(monkeypatch: pytest.MonkeyPatch):
    """Check that an `LLMEngine` built with `DummyV0Scheduler` surfaces its error.

    With `VLLM_USE_V1=0` set for the duration of the test, the engine is
    created with `scheduler_cls=DummyV0Scheduler`, one request is added and one
    step is run. The exception captured along the way must carry the message
    "Exception raised by DummyV0Scheduler".
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")
        with pytest.raises(Exception) as exception_info:

            engine_args = EngineArgs(
                model="facebook/opt-125m",
                enforce_eager=True,  # reduce test time
                scheduler_cls=DummyV0Scheduler,
            )

            engine = LLMEngine.from_engine_args(engine_args=engine_args)

            sampling_params = SamplingParams(max_tokens=1)
            engine.add_request("0", "foo", sampling_params)
            engine.step()

        # Checked after the `raises` block has captured the exception.
        assert str(
            exception_info.value) == "Exception raised by DummyV0Scheduler"
def test_scheduler_plugins_v1(monkeypatch: pytest.MonkeyPatch):
    """Check that a `V1LLMEngine` built with `DummyV1Scheduler` surfaces its error.

    `VLLM_USE_V1=1` and `VLLM_ENABLE_V1_MULTIPROCESSING=0` are set for the
    duration of the test. The engine is created with
    `scheduler_cls=DummyV1Scheduler`, one request is added and one step is run.
    The captured exception must carry the message
    "Exception raised by DummyV1Scheduler".
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        # Explicitly turn off engine multiprocessing so
        # that the scheduler runs in this process
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

        with pytest.raises(Exception) as exception_info:

            engine_args = EngineArgs(
                model="facebook/opt-125m",
                enforce_eager=True,  # reduce test time
                scheduler_cls=DummyV1Scheduler,
            )

            engine = V1LLMEngine.from_engine_args(engine_args=engine_args)

            sampling_params = SamplingParams(max_tokens=1)
            engine.add_request("0", "foo", sampling_params)
            engine.step()

        # Checked after the `raises` block has captured the exception.
        assert str(
            exception_info.value) == "Exception raised by DummyV1Scheduler"
...@@ -4,25 +4,29 @@ ...@@ -4,25 +4,29 @@
Run `pytest tests/prefix_caching/test_prefix_caching.py`. Run `pytest tests/prefix_caching/test_prefix_caching.py`.
""" """
from __future__ import annotations
import pytest import pytest
from tests.conftest import VllmRunner from tests.conftest import VllmRunner
from tests.core.utils import SchedulerProxy, create_dummy_prompt from tests.core.utils import SchedulerProxy, create_dummy_prompt
from tests.kernels.utils import override_backend_env_variable
from vllm import SamplingParams, TokensPrompt from vllm import SamplingParams, TokensPrompt
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
    """
    This module relies on V0 internals, so set VLLM_USE_V1=0.
    """
    with monkeypatch.context() as m:
        m.setenv('VLLM_USE_V1', '0')
        # Yield inside the context so the variable stays set while the test
        # body runs; the context undoes the change once the test finishes.
        yield
MODELS = [ MODELS = [
...@@ -56,7 +60,7 @@ def test_mixed_requests( ...@@ -56,7 +60,7 @@ def test_mixed_requests(
cached_position: int, cached_position: int,
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
block_size: int, block_size: int,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
""" """
Test the case when some sequences have the prefix cache hit Test the case when some sequences have the prefix cache hit
...@@ -67,72 +71,77 @@ def test_mixed_requests( ...@@ -67,72 +71,77 @@ def test_mixed_requests(
pytest.skip("Flashinfer does not support ROCm/HIP.") pytest.skip("Flashinfer does not support ROCm/HIP.")
if backend == "XFORMERS" and current_platform.is_rocm(): if backend == "XFORMERS" and current_platform.is_rocm():
pytest.skip("Xformers does not support ROCm/HIP.") pytest.skip("Xformers does not support ROCm/HIP.")
override_backend_env_variable(monkeypatch, backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, backend)
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
cached_prompt = example_prompts[cached_position]
with vllm_runner( cached_prompt = example_prompts[cached_position]
model, with vllm_runner(
dtype=dtype, model,
enable_prefix_caching=True, dtype=dtype,
enable_chunked_prefill=enable_chunked_prefill, enable_prefix_caching=True,
block_size=block_size, enable_chunked_prefill=enable_chunked_prefill,
) as vllm_model: block_size=block_size,
# Run the first prompt so the cache is populated ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens) # Run the first prompt so the cache is populated
vllm_outputs = vllm_model.generate_greedy([cached_prompt],
# Run all the promopts max_tokens)
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
req_outputs = vllm_model.model.generate(example_prompts, greedy_params) # Run all the promopts
greedy_params = SamplingParams(temperature=0.0,
# Verify number of cached tokens max_tokens=max_tokens)
for i in range(len(req_outputs)): req_outputs = vllm_model.model.generate(example_prompts,
if i == cached_position: greedy_params)
expected_num_cached_tokens = (
len(req_outputs[i].prompt_token_ids) // # Verify number of cached tokens
block_size) * block_size for i in range(len(req_outputs)):
else: if i == cached_position:
expected_num_cached_tokens = 0 expected_num_cached_tokens = (
assert ( len(req_outputs[i].prompt_token_ids) //
req_outputs[i].num_cached_tokens == expected_num_cached_tokens) block_size) * block_size
else:
vllm_outputs = [( expected_num_cached_tokens = 0
output.prompt_token_ids + list(output.outputs[0].token_ids), assert (req_outputs[i].num_cached_tokens ==
output.prompt + output.outputs[0].text, expected_num_cached_tokens)
) for output in req_outputs]
vllm_outputs = [(
check_outputs_equal( output.prompt_token_ids + list(output.outputs[0].token_ids),
outputs_0_lst=hf_outputs, output.prompt + output.outputs[0].text,
outputs_1_lst=vllm_outputs, ) for output in req_outputs]
name_0="hf",
name_1="vllm", check_outputs_equal(
) outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
def test_unstable_prompt_sequence(
    vllm_runner,
    backend: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Run every prompt in `UNSTABLE_PROMPT_SEQUENCE` through one model instance.

    The model is created with chunked prefill and prefix caching enabled, and
    each prompt is generated with `max_tokens=1`. There are no output
    assertions: the test passes when no call raises. FLASHINFER and XFORMERS
    are skipped on ROCm.
    """

    if backend == "FLASHINFER" and current_platform.is_rocm():
        pytest.skip("Flashinfer does not support ROCm/HIP.")
    if backend == "XFORMERS" and current_platform.is_rocm():
        pytest.skip("Xformers does not support ROCm/HIP.")

    with monkeypatch.context() as m:
        # Select the attention backend only for the duration of this test.
        m.setenv(STR_BACKEND_ENV_VAR, backend)

        with vllm_runner(
                "Qwen/Qwen2.5-0.5B-Instruct",
                enable_chunked_prefill=True,
                enable_prefix_caching=True,
                max_model_len=4096,
        ) as vllm_model:
            for prompt in UNSTABLE_PROMPT_SEQUENCE:
                vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
                                    SamplingParams(max_tokens=1))
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
......
...@@ -56,12 +56,11 @@ def test_gc(): ...@@ -56,12 +56,11 @@ def test_gc():
assert allocated < 50 * 1024 * 1024 assert allocated < 50 * 1024 * 1024
def test_model_from_modelscope(monkeypatch): def test_model_from_modelscope(monkeypatch: pytest.MonkeyPatch):
# model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary # model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary
MODELSCOPE_MODEL_NAME = "qwen/Qwen1.5-0.5B-Chat" with monkeypatch.context() as m:
monkeypatch.setenv("VLLM_USE_MODELSCOPE", "True") m.setenv("VLLM_USE_MODELSCOPE", "True")
try: llm = LLM(model="qwen/Qwen1.5-0.5B-Chat")
llm = LLM(model=MODELSCOPE_MODEL_NAME)
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
...@@ -73,10 +72,3 @@ def test_model_from_modelscope(monkeypatch): ...@@ -73,10 +72,3 @@ def test_model_from_modelscope(monkeypatch):
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
assert len(outputs) == 4 assert len(outputs) == 4
finally:
monkeypatch.delenv("VLLM_USE_MODELSCOPE", raising=False)
if __name__ == "__main__":
import pytest
pytest.main([__file__])
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# ruff: noqa
import asyncio import asyncio
import os
import socket import socket
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from unittest.mock import patch from unittest.mock import patch
...@@ -112,16 +112,16 @@ def test_deprecate_kwargs_additional_message(): ...@@ -112,16 +112,16 @@ def test_deprecate_kwargs_additional_message():
dummy(old_arg=1) dummy(old_arg=1)
def test_get_open_port(): def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
os.environ["VLLM_PORT"] = "5678" with monkeypatch.context() as m:
# make sure we can get multiple ports, even if the env var is set m.setenv("VLLM_PORT", "5678")
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1: # make sure we can get multiple ports, even if the env var is set
s1.bind(("localhost", get_open_port())) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2: s1.bind(("localhost", get_open_port()))
s2.bind(("localhost", get_open_port())) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3: s2.bind(("localhost", get_open_port()))
s3.bind(("localhost", get_open_port())) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
os.environ.pop("VLLM_PORT") s3.bind(("localhost", get_open_port()))
# Tests for FlexibleArgumentParser # Tests for FlexibleArgumentParser
...@@ -366,31 +366,32 @@ def test_bind_kv_cache_non_attention(): ...@@ -366,31 +366,32 @@ def test_bind_kv_cache_non_attention():
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1] assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]
def test_bind_kv_cache_encoder_decoder(monkeypatch): def test_bind_kv_cache_encoder_decoder(monkeypatch: pytest.MonkeyPatch):
# V1 TESTS: ENCODER_DECODER is not supported on V1 yet. # V1 TESTS: ENCODER_DECODER is not supported on V1 yet.
monkeypatch.setenv("VLLM_USE_V1", "0") with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
from vllm.attention import Attention, AttentionType from vllm.attention import Attention, AttentionType
# example from bart # example from bart
ctx = { ctx = {
'encoder.layers.0.self_attn.attn': 'encoder.layers.0.self_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER), Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
'decoder.layers.0.encoder_attn.attn': 'decoder.layers.0.encoder_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER), Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
'decoder.layers.0.self_attn.attn': 'decoder.layers.0.self_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.DECODER), Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
} }
kv_cache = [ kv_cache = [
torch.zeros((1, )), torch.zeros((1, )),
] ]
encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache
bind_kv_cache(ctx, [kv_cache]) bind_kv_cache(ctx, [kv_cache])
assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0] assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0] assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
def test_bind_kv_cache_pp(): def test_bind_kv_cache_pp():
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import pytest
from vllm.config import CompilationLevel from vllm.config import CompilationLevel
...@@ -9,16 +9,17 @@ from ..utils import compare_two_settings ...@@ -9,16 +9,17 @@ from ..utils import compare_two_settings
# --enforce-eager on TPU causes graph compilation # --enforce-eager on TPU causes graph compilation
# this times out default Health Check in the MQLLMEngine, # this times out default Health Check in the MQLLMEngine,
# so we set the timeout here to 30s # so we set the timeout here to 30s
os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def test_custom_dispatcher(): def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch):
compare_two_settings( with monkeypatch.context() as m:
"google/gemma-2b", m.setenv("VLLM_RPC_TIMEOUT", "30000")
arg1=[ compare_two_settings(
"--enforce-eager", "google/gemma-2b",
f"-O{CompilationLevel.DYNAMO_ONCE}", arg1=[
], "--enforce-eager",
arg2=["--enforce-eager", f"-O{CompilationLevel.DYNAMO_AS_IS}"], f"-O{CompilationLevel.DYNAMO_ONCE}",
env1={}, ],
env2={}) arg2=["--enforce-eager", f"-O{CompilationLevel.DYNAMO_AS_IS}"],
env1={},
env2={})
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# ruff: noqa
# type: ignore
from __future__ import annotations
import os
import threading import threading
from collections.abc import Iterable from collections.abc import Iterable
from concurrent import futures from concurrent import futures
from typing import Callable, Literal from typing import Callable, Generator, Literal
import grpc import grpc
import pytest import pytest
...@@ -21,12 +23,14 @@ from vllm.tracing import SpanAttributes ...@@ -21,12 +23,14 @@ from vllm.tracing import SpanAttributes
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch): def use_v0_only(monkeypatch: pytest.MonkeyPatch):
""" """
Since this module is V0 only, set VLLM_USE_V1=0 for Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module. all tests in the module.
""" """
monkeypatch.setenv('VLLM_USE_V1', '0') with monkeypatch.context() as m:
m.setenv('VLLM_USE_V1', '0')
yield
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
...@@ -67,7 +71,7 @@ class FakeTraceService(TraceServiceServicer): ...@@ -67,7 +71,7 @@ class FakeTraceService(TraceServiceServicer):
@pytest.fixture @pytest.fixture
def trace_service(): def trace_service() -> Generator[FakeTraceService, None, None]:
"""Fixture to set up a fake gRPC trace service""" """Fixture to set up a fake gRPC trace service"""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
service = FakeTraceService() service = FakeTraceService()
...@@ -80,136 +84,153 @@ def trace_service(): ...@@ -80,136 +84,153 @@ def trace_service():
server.stop(None) server.stop(None)
def test_traces(trace_service): def test_traces(
os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true" monkeypatch: pytest.MonkeyPatch,
trace_service: FakeTraceService,
sampling_params = SamplingParams(temperature=0.01, ):
top_p=0.1, with monkeypatch.context() as m:
max_tokens=256) m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
model = "facebook/opt-125m"
llm = LLM( sampling_params = SamplingParams(
model=model, temperature=0.01,
otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, top_p=0.1,
) max_tokens=256,
prompts = ["This is a short prompt"] )
outputs = llm.generate(prompts, sampling_params=sampling_params) model = "facebook/opt-125m"
llm = LLM(
timeout = 5 model=model,
if not trace_service.evt.wait(timeout): otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
raise TimeoutError( )
f"The fake trace service didn't receive a trace within " prompts = ["This is a short prompt"]
f"the {timeout} seconds timeout") outputs = llm.generate(prompts, sampling_params=sampling_params)
request = trace_service.request timeout = 5
assert len(request.resource_spans) == 1, ( if not trace_service.evt.wait(timeout):
f"Expected 1 resource span, " raise TimeoutError(
f"but got {len(request.resource_spans)}") f"The fake trace service didn't receive a trace within "
assert len(request.resource_spans[0].scope_spans) == 1, ( f"the {timeout} seconds timeout")
f"Expected 1 scope span, "
f"but got {len(request.resource_spans[0].scope_spans)}") request = trace_service.request
assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( assert len(request.resource_spans) == 1, (
f"Expected 1 span, " f"Expected 1 resource span, "
f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") f"but got {len(request.resource_spans)}")
assert len(request.resource_spans[0].scope_spans) == 1, (
attributes = decode_attributes( f"Expected 1 scope span, "
request.resource_spans[0].scope_spans[0].spans[0].attributes) f"but got {len(request.resource_spans[0].scope_spans)}")
assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
assert attributes.get( f"Expected 1 span, "
SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
) == sampling_params.temperature attributes = decode_attributes(
assert attributes.get( request.resource_spans[0].scope_spans[0].spans[0].attributes)
SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS) == sampling_params.max_tokens SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( ) == sampling_params.temperature
outputs[0].prompt_token_ids) assert attributes.get(
completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
assert attributes.get( assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens ) == sampling_params.max_tokens
metrics = outputs[0].metrics assert attributes.get(
assert attributes.get( SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue assert attributes.get(
ttft = metrics.first_token_time - metrics.arrival_time SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
assert attributes.get( outputs[0].prompt_token_ids)
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
e2e_time = metrics.finished_time - metrics.arrival_time assert attributes.get(
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
assert metrics.scheduler_time > 0 metrics = outputs[0].metrics
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE
) == metrics.scheduler_time ) == metrics.time_in_queue
# Model forward and model execute should be none, since detailed traces is ttft = metrics.first_token_time - metrics.arrival_time
# not enabled. assert attributes.get(
assert metrics.model_forward_time is None SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
assert metrics.model_execute_time is None e2e_time = metrics.finished_time - metrics.arrival_time
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time
assert metrics.scheduler_time > 0
def test_traces_with_detailed_steps(trace_service): assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER
os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true" ) == metrics.scheduler_time
# Model forward and model execute should be none, since detailed traces is
sampling_params = SamplingParams(temperature=0.01, # not enabled.
top_p=0.1, assert metrics.model_forward_time is None
max_tokens=256) assert metrics.model_execute_time is None
model = "facebook/opt-125m"
llm = LLM(
model=model, def test_traces_with_detailed_steps(
otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, monkeypatch: pytest.MonkeyPatch,
collect_detailed_traces="all", trace_service: FakeTraceService,
) ):
prompts = ["This is a short prompt"] with monkeypatch.context() as m:
outputs = llm.generate(prompts, sampling_params=sampling_params) m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
timeout = 5 sampling_params = SamplingParams(
if not trace_service.evt.wait(timeout): temperature=0.01,
raise TimeoutError( top_p=0.1,
f"The fake trace service didn't receive a trace within " max_tokens=256,
f"the {timeout} seconds timeout") )
model = "facebook/opt-125m"
request = trace_service.request llm = LLM(
assert len(request.resource_spans) == 1, ( model=model,
f"Expected 1 resource span, " otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
f"but got {len(request.resource_spans)}") collect_detailed_traces="all",
assert len(request.resource_spans[0].scope_spans) == 1, ( )
f"Expected 1 scope span, " prompts = ["This is a short prompt"]
f"but got {len(request.resource_spans[0].scope_spans)}") outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
f"Expected 1 span, " timeout = 5
f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") if not trace_service.evt.wait(timeout):
raise TimeoutError(
attributes = decode_attributes( f"The fake trace service didn't receive a trace within "
request.resource_spans[0].scope_spans[0].spans[0].attributes) f"the {timeout} seconds timeout")
assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
assert attributes.get( request = trace_service.request
SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id assert len(request.resource_spans) == 1, (
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE f"Expected 1 resource span, "
) == sampling_params.temperature f"but got {len(request.resource_spans)}")
assert attributes.get( assert len(request.resource_spans[0].scope_spans) == 1, (
SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p f"Expected 1 scope span, "
assert attributes.get( f"but got {len(request.resource_spans[0].scope_spans)}")
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS) == sampling_params.max_tokens assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n f"Expected 1 span, "
assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")
outputs[0].prompt_token_ids)
completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) attributes = decode_attributes(
assert attributes.get( request.resource_spans[0].scope_spans[0].spans[0].attributes)
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
metrics = outputs[0].metrics assert attributes.get(
assert attributes.get( SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
ttft = metrics.first_token_time - metrics.arrival_time ) == sampling_params.temperature
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
e2e_time = metrics.finished_time - metrics.arrival_time assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time ) == sampling_params.max_tokens
assert metrics.scheduler_time > 0 assert attributes.get(
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
) == metrics.scheduler_time assert attributes.get(
assert metrics.model_forward_time > 0 SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
assert attributes.get( outputs[0].prompt_token_ids)
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD) == pytest.approx( completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
metrics.model_forward_time / 1000) assert attributes.get(
assert metrics.model_execute_time > 0 SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE metrics = outputs[0].metrics
) == metrics.model_execute_time assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE
assert metrics.model_forward_time < 1000 * metrics.model_execute_time ) == metrics.time_in_queue
ttft = metrics.first_token_time - metrics.arrival_time
assert attributes.get(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
e2e_time = metrics.finished_time - metrics.arrival_time
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time
assert metrics.scheduler_time > 0
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER
) == metrics.scheduler_time
assert metrics.model_forward_time > 0
assert attributes.get(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD
) == pytest.approx(metrics.model_forward_time / 1000)
assert metrics.model_execute_time > 0
assert attributes.get(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE
) == metrics.model_execute_time
assert metrics.model_forward_time < 1000 * metrics.model_execute_time
...@@ -566,6 +566,7 @@ def init_test_distributed_environment( ...@@ -566,6 +566,7 @@ def init_test_distributed_environment(
def multi_process_parallel( def multi_process_parallel(
monkeypatch: pytest.MonkeyPatch,
tp_size: int, tp_size: int,
pp_size: int, pp_size: int,
test_target: Any, test_target: Any,
...@@ -582,7 +583,13 @@ def multi_process_parallel( ...@@ -582,7 +583,13 @@ def multi_process_parallel(
refs = [] refs = []
for rank in range(tp_size * pp_size): for rank in range(tp_size * pp_size):
refs.append( refs.append(
test_target.remote(tp_size, pp_size, rank, distributed_init_port)) test_target.remote(
monkeypatch,
tp_size,
pp_size,
rank,
distributed_init_port,
), )
ray.get(refs) ray.get(refs)
ray.shutdown() ray.shutdown()
...@@ -700,7 +707,7 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator: ...@@ -700,7 +707,7 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
""" """
Get a pytest mark, which skips the test if the GPU doesn't meet Get a pytest mark, which skips the test if the GPU doesn't meet
a minimum memory requirement in GB. a minimum memory requirement in GB.
This can be leveraged via `@large_gpu_test` to skip tests in environments This can be leveraged via `@large_gpu_test` to skip tests in environments
without enough resources, or called when filtering tests to run directly. without enough resources, or called when filtering tests to run directly.
""" """
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import random import random
from typing import Any
import pytest import pytest
...@@ -50,8 +53,12 @@ def model_name(): ...@@ -50,8 +53,12 @@ def model_name():
return "meta-llama/Meta-Llama-3-8B-Instruct" return "meta-llama/Meta-Llama-3-8B-Instruct"
def test_ngram_correctness(monkeypatch, test_prompts, sampling_config, def test_ngram_correctness(
model_name): monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
''' '''
Compare the outputs of a original LLM and a speculative LLM Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding. should be the same when using ngram speculative decoding.
......
...@@ -80,9 +80,11 @@ async def generate(engine: AsyncLLM, ...@@ -80,9 +80,11 @@ async def generate(engine: AsyncLLM,
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)]) (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load(monkeypatch, output_kind: RequestOutputKind, async def test_load(
engine_args_and_prompt: tuple[AsyncEngineArgs, monkeypatch: pytest.MonkeyPatch,
PromptType]): output_kind: RequestOutputKind,
engine_args_and_prompt: tuple[AsyncEngineArgs, PromptType],
):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1 # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# so that in the future when we switch, we don't have to change all the # so that in the future when we switch, we don't have to change all the
# tests. # tests.
...@@ -126,7 +128,8 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind, ...@@ -126,7 +128,8 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind,
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)]) (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_abort(monkeypatch, output_kind: RequestOutputKind, async def test_abort(monkeypatch: pytest.MonkeyPatch,
output_kind: RequestOutputKind,
engine_args_and_prompt: tuple[AsyncEngineArgs, engine_args_and_prompt: tuple[AsyncEngineArgs,
PromptType]): PromptType]):
......
...@@ -45,7 +45,7 @@ def make_request() -> EngineCoreRequest: ...@@ -45,7 +45,7 @@ def make_request() -> EngineCoreRequest:
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_engine_core(monkeypatch): def test_engine_core(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
...@@ -159,10 +159,10 @@ def test_engine_core(monkeypatch): ...@@ -159,10 +159,10 @@ def test_engine_core(monkeypatch):
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_engine_core_advanced_sampling(monkeypatch): def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
""" """
A basic end-to-end test to verify that the engine functions correctly A basic end-to-end test to verify that the engine functions correctly
when additional sampling parameters, such as top_p, min_tokens, and when additional sampling parameters, such as top_p, min_tokens, and
presence_penalty, are set. presence_penalty, are set.
""" """
with monkeypatch.context() as m: with monkeypatch.context() as m:
...@@ -209,7 +209,7 @@ def test_engine_core_advanced_sampling(monkeypatch): ...@@ -209,7 +209,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_engine_core_concurrent_batches(monkeypatch): def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
""" """
Test that the engine can handle multiple concurrent batches. Test that the engine can handle multiple concurrent batches.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment