Unverified commit a73e183e, authored by Sibi, committed by GitHub
Browse files

[Misc] Replace os.environ with monkeypatch in test suite (#14516)


Signed-off-by: sibi <85477603+t-sibiraj@users.noreply.github.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
parent 1e799b7e
...@@ -12,11 +12,10 @@ import pytest ...@@ -12,11 +12,10 @@ import pytest
from tests.kernels.utils import override_backend_env_variable from tests.kernels.utils import override_backend_env_variable
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
os.environ["TOKENIZERS_PARALLELISM"] = "true"
@pytest.mark.quant_model @pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("fp8"), @pytest.mark.skipif(not is_quant_method_supported("fp8"),
...@@ -55,13 +54,15 @@ def test_models( ...@@ -55,13 +54,15 @@ def test_models(
backend: str, backend: str,
tensor_parallel_size: int, tensor_parallel_size: int,
disable_async_output_proc: bool, disable_async_output_proc: bool,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
""" """
Only checks log probs match to cover the discrepancy in Only checks log probs match to cover the discrepancy in
numerical sensitive kernels. numerical sensitive kernels.
""" """
override_backend_env_variable(monkeypatch, backend) with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", 'true')
m.setenv(STR_BACKEND_ENV_VAR, backend)
MAX_MODEL_LEN = 1024 MAX_MODEL_LEN = 1024
NUM_LOG_PROBS = 8 NUM_LOG_PROBS = 8
...@@ -119,11 +120,14 @@ def test_cpu_models( ...@@ -119,11 +120,14 @@ def test_cpu_models(
test_model: str, test_model: str,
max_tokens: int, max_tokens: int,
disable_async_output_proc: bool, disable_async_output_proc: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
""" """
Only checks log probs match to cover the discrepancy in Only checks log probs match to cover the discrepancy in
numerical sensitive kernels. numerical sensitive kernels.
""" """
with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", 'true')
MAX_MODEL_LEN = 1024 MAX_MODEL_LEN = 1024
NUM_LOG_PROBS = 8 NUM_LOG_PROBS = 8
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import importlib.util import importlib.util
import math import math
...@@ -11,6 +12,7 @@ from scipy.spatial.distance import cosine ...@@ -11,6 +12,7 @@ from scipy.spatial.distance import cosine
import vllm import vllm
import vllm.config import vllm.config
from vllm.utils import STR_BACKEND_ENV_VAR
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
...@@ -29,9 +31,10 @@ def _arr(arr): ...@@ -29,9 +31,10 @@ def _arr(arr):
return array("i", arr) return array("i", arr)
def test_find_array(monkeypatch): def test_find_array(monkeypatch: pytest.MonkeyPatch):
# GritLM embedding implementation is only supported by XFormers backend. # GritLM embedding implementation is only supported by XFormers backend.
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
from vllm.model_executor.models.gritlm import GritLMPooler from vllm.model_executor.models.gritlm import GritLMPooler
...@@ -53,9 +56,6 @@ def test_find_array(monkeypatch): ...@@ -53,9 +56,6 @@ def test_find_array(monkeypatch):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server_embedding(): def server_embedding():
# GritLM embedding implementation is only supported by XFormers backend. # GritLM embedding implementation is only supported by XFormers backend.
with pytest.MonkeyPatch.context() as mp:
mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)] args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server yield remote_server
...@@ -69,7 +69,10 @@ def server_generate(): ...@@ -69,7 +69,10 @@ def server_generate():
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def client_embedding(server_embedding: RemoteOpenAIServer): async def client_embedding(monkeypatch: pytest.MonkeyPatch,
server_embedding: RemoteOpenAIServer):
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
async with server_embedding.get_async_client() as async_client: async with server_embedding.get_async_client() as async_client:
yield async_client yield async_client
...@@ -80,14 +83,20 @@ async def client_generate(server_generate: RemoteOpenAIServer): ...@@ -80,14 +83,20 @@ async def client_generate(server_generate: RemoteOpenAIServer):
yield async_client yield async_client
def run_llm_encode(llm: vllm.LLM, queries: list[str], def run_llm_encode(
instruction: str) -> list[float]: llm: vllm.LLM,
queries: list[str],
instruction: str,
) -> list[float]:
outputs = llm.encode([instruction + q for q in queries], ) outputs = llm.encode([instruction + q for q in queries], )
return [output.outputs.embedding for output in outputs] return [output.outputs.embedding for output in outputs]
async def run_client_embeddings(client: vllm.LLM, queries: list[str], async def run_client_embeddings(
instruction: str) -> list[float]: client: vllm.LLM,
queries: list[str],
instruction: str,
) -> list[float]:
outputs = await client.embeddings.create( outputs = await client.embeddings.create(
model=MODEL_NAME, model=MODEL_NAME,
input=[instruction + q for q in queries], input=[instruction + q for q in queries],
...@@ -106,7 +115,7 @@ def get_test_data(): ...@@ -106,7 +115,7 @@ def get_test_data():
README.md in https://github.com/ContextualAI/gritlm README.md in https://github.com/ContextualAI/gritlm
""" """
q_instruction = gritlm_instruction( q_instruction = gritlm_instruction(
"Given a scientific paper title, retrieve the paper's abstract") "Given a scientific paper title, retrieve the paper's abstract", )
queries = [ queries = [
"Bitcoin: A Peer-to-Peer Electronic Cash System", "Bitcoin: A Peer-to-Peer Electronic Cash System",
"Generative Representational Instruction Tuning", "Generative Representational Instruction Tuning",
...@@ -136,9 +145,10 @@ def validate_embed_output(q_rep: list[float], d_rep: list[float]): ...@@ -136,9 +145,10 @@ def validate_embed_output(q_rep: list[float], d_rep: list[float]):
assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001) assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001)
def test_gritlm_offline_embedding(monkeypatch): def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch):
# GritLM embedding implementation is only supported by XFormers backend. # GritLM embedding implementation is only supported by XFormers backend.
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
queries, q_instruction, documents, d_instruction = get_test_data() queries, q_instruction, documents, d_instruction = get_test_data()
...@@ -160,7 +170,7 @@ def test_gritlm_offline_embedding(monkeypatch): ...@@ -160,7 +170,7 @@ def test_gritlm_offline_embedding(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gritlm_api_server_embedding( async def test_gritlm_api_server_embedding(
client_embedding: openai.AsyncOpenAI): client_embedding: openai.AsyncOpenAI, ):
queries, q_instruction, documents, d_instruction = get_test_data() queries, q_instruction, documents, d_instruction = get_test_data()
d_rep = await run_client_embeddings( d_rep = await run_client_embeddings(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
...@@ -11,20 +9,28 @@ from ..utils import fork_new_process_for_each_test ...@@ -11,20 +9,28 @@ from ..utils import fork_new_process_for_each_test
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_plugin(dummy_opt_path, monkeypatch): def test_plugin(
monkeypatch: pytest.MonkeyPatch,
dummy_opt_path: str,
):
# V1 shuts down rather than raising an error here. # V1 shuts down rather than raising an error here.
monkeypatch.setenv("VLLM_USE_V1", "0") with monkeypatch.context() as m:
os.environ["VLLM_PLUGINS"] = "" m.setenv("VLLM_USE_V1", "0")
m.setenv("VLLM_PLUGINS", "")
with pytest.raises(Exception) as excinfo: with pytest.raises(Exception) as excinfo:
LLM(model=dummy_opt_path, load_format="dummy") LLM(model=dummy_opt_path, load_format="dummy")
error_msg = "has no vLLM implementation and " \ error_msg = "has no vLLM implementation and the Transformers implementation is not compatible with vLLM" # noqa: E501
"the Transformers implementation is not compatible with vLLM"
assert (error_msg in str(excinfo.value)) assert (error_msg in str(excinfo.value))
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_oot_registration_text_generation(dummy_opt_path): def test_oot_registration_text_generation(
os.environ["VLLM_PLUGINS"] = "register_dummy_model" monkeypatch: pytest.MonkeyPatch,
dummy_opt_path: str,
):
with monkeypatch.context() as m:
m.setenv("VLLM_PLUGINS", "register_dummy_model")
prompts = ["Hello, my name is", "The text does not matter"] prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
llm = LLM(model=dummy_opt_path, load_format="dummy") llm = LLM(model=dummy_opt_path, load_format="dummy")
...@@ -39,8 +45,12 @@ def test_oot_registration_text_generation(dummy_opt_path): ...@@ -39,8 +45,12 @@ def test_oot_registration_text_generation(dummy_opt_path):
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_oot_registration_embedding(dummy_gemma2_embedding_path): def test_oot_registration_embedding(
os.environ["VLLM_PLUGINS"] = "register_dummy_model" monkeypatch: pytest.MonkeyPatch,
dummy_gemma2_embedding_path: str,
):
with monkeypatch.context() as m:
m.setenv("VLLM_PLUGINS", "register_dummy_model")
prompts = ["Hello, my name is", "The text does not matter"] prompts = ["Hello, my name is", "The text does not matter"]
llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy") llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
outputs = llm.embed(prompts) outputs = llm.embed(prompts)
...@@ -53,8 +63,12 @@ image = ImageAsset("cherry_blossom").pil_image.convert("RGB") ...@@ -53,8 +63,12 @@ image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_oot_registration_multimodal(dummy_llava_path, monkeypatch): def test_oot_registration_multimodal(
os.environ["VLLM_PLUGINS"] = "register_dummy_model" monkeypatch: pytest.MonkeyPatch,
dummy_llava_path: str,
):
with monkeypatch.context() as m:
m.setenv("VLLM_PLUGINS", "register_dummy_model")
prompts = [{ prompts = [{
"prompt": "What's in the image?<image>", "prompt": "What's in the image?<image>",
"multi_modal_data": { "multi_modal_data": {
......
...@@ -235,9 +235,11 @@ async def test_bad_request(tmp_socket): ...@@ -235,9 +235,11 @@ async def test_bad_request(tmp_socket):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch): async def test_mp_crash_detection(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser) parser = make_arg_parser(parser)
args = parser.parse_args([]) args = parser.parse_args([])
...@@ -245,14 +247,15 @@ async def test_mp_crash_detection(monkeypatch): ...@@ -245,14 +247,15 @@ async def test_mp_crash_detection(monkeypatch):
def mock_init(): def mock_init():
raise ValueError raise ValueError
monkeypatch.setattr(LLMEngine, "__init__", mock_init) m.setattr(LLMEngine, "__init__", mock_init)
start = time.perf_counter() start = time.perf_counter()
async with build_async_engine_client(args): async with build_async_engine_client(args):
pass pass
end = time.perf_counter() end = time.perf_counter()
assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " assert end - start < 60, (
"Expected vLLM to gracefully shutdown in <60s "
"if there is an error in the startup.") "if there is an error in the startup.")
......
...@@ -5,7 +5,7 @@ from typing import Optional ...@@ -5,7 +5,7 @@ from typing import Optional
import pytest import pytest
from tests.kernels.utils import override_backend_env_variable from vllm.utils import STR_BACKEND_ENV_VAR
from ..models.utils import check_logprobs_close from ..models.utils import check_logprobs_close
from ..utils import (completions_with_server_args, get_client_text_generations, from ..utils import (completions_with_server_args, get_client_text_generations,
...@@ -52,7 +52,7 @@ async def test_multi_step( ...@@ -52,7 +52,7 @@ async def test_multi_step(
num_logprobs: Optional[int], num_logprobs: Optional[int],
attention_backend: str, attention_backend: str,
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol """Test vLLM engine with multi-step scheduling in an OpenAI-protocol
client/server environment. client/server environment.
...@@ -82,7 +82,8 @@ async def test_multi_step( ...@@ -82,7 +82,8 @@ async def test_multi_step(
pytest.skip("Multi-step with Chunked-Prefill only supports" pytest.skip("Multi-step with Chunked-Prefill only supports"
"PP=1 and FLASH_ATTN backend") "PP=1 and FLASH_ATTN backend")
override_backend_env_variable(monkeypatch, attention_backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
prompts = example_prompts prompts = example_prompts
if len(prompts) < num_prompts: if len(prompts) < num_prompts:
...@@ -135,8 +136,10 @@ async def test_multi_step( ...@@ -135,8 +136,10 @@ async def test_multi_step(
# Assert multi-step scheduling produces nearly-identical logprobs # Assert multi-step scheduling produces nearly-identical logprobs
# to single-step scheduling. # to single-step scheduling.
ref_text_logprobs = get_client_text_logprob_generations(ref_completions) ref_text_logprobs = get_client_text_logprob_generations(
test_text_logprobs = get_client_text_logprob_generations(test_completions) ref_completions)
test_text_logprobs = get_client_text_logprob_generations(
test_completions)
check_logprobs_close( check_logprobs_close(
outputs_0_lst=ref_text_logprobs, outputs_0_lst=ref_text_logprobs,
outputs_1_lst=test_text_logprobs, outputs_1_lst=test_text_logprobs,
...@@ -152,7 +155,7 @@ async def test_multi_step( ...@@ -152,7 +155,7 @@ async def test_multi_step(
async def test_multi_step_pp_smoke( async def test_multi_step_pp_smoke(
tp_size: int, tp_size: int,
pp_size: int, pp_size: int,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
""" """
Smoke test for the vLLM engine with multi-step scheduling in an Smoke test for the vLLM engine with multi-step scheduling in an
...@@ -174,7 +177,8 @@ async def test_multi_step_pp_smoke( ...@@ -174,7 +177,8 @@ async def test_multi_step_pp_smoke(
attention_backend = "FLASH_ATTN" attention_backend = "FLASH_ATTN"
max_num_seqs = 3 max_num_seqs = 3
override_backend_env_variable(monkeypatch, attention_backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
# Prompt from the ShareGPT dataset # Prompt from the ShareGPT dataset
prompts = [ prompts = [
......
...@@ -7,7 +7,7 @@ from typing import Optional ...@@ -7,7 +7,7 @@ from typing import Optional
import pytest import pytest
from tests.kernels.utils import override_backend_env_variable from vllm.utils import STR_BACKEND_ENV_VAR
from ..models.utils import check_logprobs_close, check_outputs_equal from ..models.utils import check_logprobs_close, check_outputs_equal
...@@ -42,7 +42,7 @@ def test_multi_step_llm( ...@@ -42,7 +42,7 @@ def test_multi_step_llm(
num_prompts: int, num_prompts: int,
num_logprobs: Optional[int], num_logprobs: Optional[int],
attention_backend: str, attention_backend: str,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine. """Test vLLM engine with multi-step scheduling via sync LLM Engine.
...@@ -70,7 +70,8 @@ def test_multi_step_llm( ...@@ -70,7 +70,8 @@ def test_multi_step_llm(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned. completions endpoint; `None` -> 1 logprob returned.
""" """
override_backend_env_variable(monkeypatch, attention_backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
prompts = example_prompts prompts = example_prompts
if len(prompts) < num_prompts: if len(prompts) < num_prompts:
...@@ -136,7 +137,7 @@ def test_multi_step_llm_w_prompt_logprobs( ...@@ -136,7 +137,7 @@ def test_multi_step_llm_w_prompt_logprobs(
num_logprobs: Optional[int], num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int], num_prompt_logprobs: Optional[int],
attention_backend: str, attention_backend: str,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine. """Test prompt logprobs with multi-step scheduling via sync LLM Engine.
...@@ -166,7 +167,8 @@ def test_multi_step_llm_w_prompt_logprobs( ...@@ -166,7 +167,8 @@ def test_multi_step_llm_w_prompt_logprobs(
note that this argument is not supported by the note that this argument is not supported by the
OpenAI completions endpoint. OpenAI completions endpoint.
""" """
override_backend_env_variable(monkeypatch, attention_backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
prompts = example_prompts prompts = example_prompts
if len(prompts) < num_prompts: if len(prompts) < num_prompts:
...@@ -230,7 +232,7 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( ...@@ -230,7 +232,7 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
num_prompts: int, num_prompts: int,
num_logprobs: Optional[int], num_logprobs: Optional[int],
attention_backend: str, attention_backend: str,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC. """Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
...@@ -293,13 +295,14 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( ...@@ -293,13 +295,14 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
# #
# The Incorrect scheduling behavior - if it occurs - will cause an exception # The Incorrect scheduling behavior - if it occurs - will cause an exception
# in the model runner resulting from `do_sample=False`. # in the model runner resulting from `do_sample=False`.
override_backend_env_variable(monkeypatch, attention_backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
assert len(example_prompts) >= 2 assert len(example_prompts) >= 2
challenge_prompts = copy.deepcopy(example_prompts) challenge_prompts = copy.deepcopy(example_prompts)
challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' challenge_prompts[0] = (
'inference and serving engine for LLMs.\n' 'vLLM is a high-throughput and memory-efficient '
) # 24 tok 'inference and serving engine for LLMs.\n') # 24 tok
challenge_prompts[1] = ( challenge_prompts[1] = (
'Briefly describe the major milestones in the ' 'Briefly describe the major milestones in the '
'development of artificial intelligence from 1950 to 2020.\n' 'development of artificial intelligence from 1950 to 2020.\n'
...@@ -326,9 +329,9 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( ...@@ -326,9 +329,9 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
max_num_seqs=4, max_num_seqs=4,
block_size=16, block_size=16,
) as vllm_model: ) as vllm_model:
outputs_baseline = (vllm_model.generate_greedy( outputs_baseline = (
challenge_prompts, max_tokens) if num_logprobs is None else vllm_model.generate_greedy(challenge_prompts, max_tokens) if
vllm_model.generate_greedy_logprobs( num_logprobs is None else vllm_model.generate_greedy_logprobs(
challenge_prompts, max_tokens, num_logprobs)) challenge_prompts, max_tokens, num_logprobs))
# multi-step+"single-step chunked prefill"+APC # multi-step+"single-step chunked prefill"+APC
...@@ -346,9 +349,9 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( ...@@ -346,9 +349,9 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
max_num_seqs=4, max_num_seqs=4,
block_size=16, block_size=16,
) as vllm_model: ) as vllm_model:
outputs_w_features = (vllm_model.generate_greedy( outputs_w_features = (
challenge_prompts, max_tokens) if num_logprobs is None else vllm_model.generate_greedy(challenge_prompts, max_tokens) if
vllm_model.generate_greedy_logprobs( num_logprobs is None else vllm_model.generate_greedy_logprobs(
challenge_prompts, max_tokens, num_logprobs)) challenge_prompts, max_tokens, num_logprobs))
if num_logprobs is None: if num_logprobs is None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import neuronxcc.nki.language as nl import neuronxcc.nki.language as nl
import pytest import pytest
...@@ -99,6 +98,7 @@ def ref_block_tables_transform( ...@@ -99,6 +98,7 @@ def ref_block_tables_transform(
) )
@torch.inference_mode() @torch.inference_mode()
def test_load_and_transform_block_tables( def test_load_and_transform_block_tables(
monkeypatch: pytest.MonkeyPatch,
num_tiles, num_tiles,
num_blocks_per_tile, num_blocks_per_tile,
q_head_per_kv_head, q_head_per_kv_head,
...@@ -108,12 +108,12 @@ def test_load_and_transform_block_tables( ...@@ -108,12 +108,12 @@ def test_load_and_transform_block_tables(
device = xm.xla_device() device = xm.xla_device()
compiler_flags = [ compiler_flags_str = " ".join([
"-O1", "-O1",
"--retry_failed_compilation", "--retry_failed_compilation",
] ])
compiler_flags_str = " ".join(compiler_flags) with monkeypatch.context() as m:
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str m.setenv("NEURON_CC_FLAGS", compiler_flags_str)
torch.manual_seed(10000) torch.manual_seed(10000)
torch.set_printoptions(sci_mode=False) torch.set_printoptions(sci_mode=False)
......
...@@ -320,6 +320,7 @@ def get_active_block_tables(block_tables, query_lens, seq_lens, block_size, ...@@ -320,6 +320,7 @@ def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
]) ])
@torch.inference_mode() @torch.inference_mode()
def test_contexted_kv_attention( def test_contexted_kv_attention(
monkeypatch: pytest.MonkeyPatch,
prefill_batch_size: int, prefill_batch_size: int,
decode_batch_size: int, decode_batch_size: int,
num_heads: int, num_heads: int,
...@@ -329,7 +330,6 @@ def test_contexted_kv_attention( ...@@ -329,7 +330,6 @@ def test_contexted_kv_attention(
large_tile_size, large_tile_size,
mixed_precision: bool, mixed_precision: bool,
) -> None: ) -> None:
import os
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
...@@ -340,12 +340,12 @@ def test_contexted_kv_attention( ...@@ -340,12 +340,12 @@ def test_contexted_kv_attention(
device = xm.xla_device() device = xm.xla_device()
compiler_flags = [ compiler_flags_str = " ".join([
"-O1", "-O1",
"--retry_failed_compilation", "--retry_failed_compilation",
] ])
compiler_flags_str = " ".join(compiler_flags) with monkeypatch.context() as m:
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str m.setenv("NEURON_CC_FLAGS", compiler_flags_str)
torch.manual_seed(0) torch.manual_seed(0)
torch.set_printoptions(sci_mode=False) torch.set_printoptions(sci_mode=False)
...@@ -415,7 +415,8 @@ def test_contexted_kv_attention( ...@@ -415,7 +415,8 @@ def test_contexted_kv_attention(
num_active_blocks = pad_to_multiple(num_active_blocks, num_active_blocks = pad_to_multiple(num_active_blocks,
large_tile_size // block_size) large_tile_size // block_size)
context_kv_len = num_active_blocks * block_size context_kv_len = num_active_blocks * block_size
assert (context_kv_len % assert (
context_kv_len %
large_tile_size == 0), f"invalid context_kv_len={context_kv_len}" large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"
# pad QKV tensors # pad QKV tensors
...@@ -476,9 +477,11 @@ def test_contexted_kv_attention( ...@@ -476,9 +477,11 @@ def test_contexted_kv_attention(
"constant", "constant",
0, 0,
).bool() ).bool()
attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1) attn_mask = torch.concat([prior_mask_padded, active_mask_padded],
dim=1)
attn_mask = reorder_context_mask(attn_mask, large_tile_size, block_size) attn_mask = reorder_context_mask(attn_mask, large_tile_size,
block_size)
input_args = ( input_args = (
query.to(device=device), query.to(device=device),
...@@ -508,6 +511,7 @@ def test_contexted_kv_attention( ...@@ -508,6 +511,7 @@ def test_contexted_kv_attention(
"constant", "constant",
0, 0,
) )
output_ref = output_ref_padded.transpose(0, 1)[0, :num_actual_tokens, :, :] output_ref = output_ref_padded.transpose(
0, 1)[0, :num_actual_tokens, :, :]
torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0) torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pytest
import torch import torch
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.utils import STR_INVALID_VAL from vllm.utils import STR_BACKEND_ENV_VAR, STR_INVALID_VAL
def test_platform_plugins(): def test_platform_plugins():
...@@ -25,8 +25,9 @@ def test_platform_plugins(): ...@@ -25,8 +25,9 @@ def test_platform_plugins():
f" is loaded. The first import:\n{_init_trace}") f" is loaded. The first import:\n{_init_trace}")
def test_oot_attention_backend(monkeypatch): def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch):
# ignore the backend env variable if it is set # ignore the backend env variable if it is set
override_backend_env_variable(monkeypatch, STR_INVALID_VAL) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
assert backend.get_name() == "Dummy_Backend" assert backend.get_name() == "Dummy_Backend"
...@@ -22,8 +22,9 @@ class DummyV1Scheduler(V1Scheduler): ...@@ -22,8 +22,9 @@ class DummyV1Scheduler(V1Scheduler):
raise Exception("Exception raised by DummyV1Scheduler") raise Exception("Exception raised by DummyV1Scheduler")
def test_scheduler_plugins_v0(monkeypatch): def test_scheduler_plugins_v0(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_V1", "0") with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
with pytest.raises(Exception) as exception_info: with pytest.raises(Exception) as exception_info:
engine_args = EngineArgs( engine_args = EngineArgs(
...@@ -38,14 +39,16 @@ def test_scheduler_plugins_v0(monkeypatch): ...@@ -38,14 +39,16 @@ def test_scheduler_plugins_v0(monkeypatch):
engine.add_request("0", "foo", sampling_params) engine.add_request("0", "foo", sampling_params)
engine.step() engine.step()
assert str(exception_info.value) == "Exception raised by DummyV0Scheduler" assert str(
exception_info.value) == "Exception raised by DummyV0Scheduler"
def test_scheduler_plugins_v1(monkeypatch): def test_scheduler_plugins_v1(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_V1", "1") with monkeypatch.context() as m:
# Explicitly turn off engine multiprocessing so that the scheduler runs in m.setenv("VLLM_USE_V1", "1")
# this process # Explicitly turn off engine multiprocessing so
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") # that the scheduler runs in this process
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with pytest.raises(Exception) as exception_info: with pytest.raises(Exception) as exception_info:
...@@ -61,4 +64,5 @@ def test_scheduler_plugins_v1(monkeypatch): ...@@ -61,4 +64,5 @@ def test_scheduler_plugins_v1(monkeypatch):
engine.add_request("0", "foo", sampling_params) engine.add_request("0", "foo", sampling_params)
engine.step() engine.step()
assert str(exception_info.value) == "Exception raised by DummyV1Scheduler" assert str(
exception_info.value) == "Exception raised by DummyV1Scheduler"
...@@ -4,25 +4,29 @@ ...@@ -4,25 +4,29 @@
Run `pytest tests/prefix_caching/test_prefix_caching.py`. Run `pytest tests/prefix_caching/test_prefix_caching.py`.
""" """
from __future__ import annotations
import pytest import pytest
from tests.conftest import VllmRunner from tests.conftest import VllmRunner
from tests.core.utils import SchedulerProxy, create_dummy_prompt from tests.core.utils import SchedulerProxy, create_dummy_prompt
from tests.kernels.utils import override_backend_env_variable
from vllm import SamplingParams, TokensPrompt from vllm import SamplingParams, TokensPrompt
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch): def use_v0_only(monkeypatch: pytest.MonkeyPatch):
""" """
This module relies on V0 internals, so set VLLM_USE_V1=0. This module relies on V0 internals, so set VLLM_USE_V1=0.
""" """
monkeypatch.setenv('VLLM_USE_V1', '0') with monkeypatch.context() as m:
m.setenv('VLLM_USE_V1', '0')
yield
MODELS = [ MODELS = [
...@@ -56,7 +60,7 @@ def test_mixed_requests( ...@@ -56,7 +60,7 @@ def test_mixed_requests(
cached_position: int, cached_position: int,
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
block_size: int, block_size: int,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
""" """
Test the case when some sequences have the prefix cache hit Test the case when some sequences have the prefix cache hit
...@@ -67,7 +71,8 @@ def test_mixed_requests( ...@@ -67,7 +71,8 @@ def test_mixed_requests(
pytest.skip("Flashinfer does not support ROCm/HIP.") pytest.skip("Flashinfer does not support ROCm/HIP.")
if backend == "XFORMERS" and current_platform.is_rocm(): if backend == "XFORMERS" and current_platform.is_rocm():
pytest.skip("Xformers does not support ROCm/HIP.") pytest.skip("Xformers does not support ROCm/HIP.")
override_backend_env_variable(monkeypatch, backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, backend)
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
...@@ -81,11 +86,14 @@ def test_mixed_requests( ...@@ -81,11 +86,14 @@ def test_mixed_requests(
block_size=block_size, block_size=block_size,
) as vllm_model: ) as vllm_model:
# Run the first prompt so the cache is populated # Run the first prompt so the cache is populated
vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens) vllm_outputs = vllm_model.generate_greedy([cached_prompt],
max_tokens)
# Run all the prompts # Run all the prompts
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) greedy_params = SamplingParams(temperature=0.0,
req_outputs = vllm_model.model.generate(example_prompts, greedy_params) max_tokens=max_tokens)
req_outputs = vllm_model.model.generate(example_prompts,
greedy_params)
# Verify number of cached tokens # Verify number of cached tokens
for i in range(len(req_outputs)): for i in range(len(req_outputs)):
...@@ -95,8 +103,8 @@ def test_mixed_requests( ...@@ -95,8 +103,8 @@ def test_mixed_requests(
block_size) * block_size block_size) * block_size
else: else:
expected_num_cached_tokens = 0 expected_num_cached_tokens = 0
assert ( assert (req_outputs[i].num_cached_tokens ==
req_outputs[i].num_cached_tokens == expected_num_cached_tokens) expected_num_cached_tokens)
vllm_outputs = [( vllm_outputs = [(
output.prompt_token_ids + list(output.outputs[0].token_ids), output.prompt_token_ids + list(output.outputs[0].token_ids),
...@@ -115,14 +123,15 @@ def test_mixed_requests( ...@@ -115,14 +123,15 @@ def test_mixed_requests(
def test_unstable_prompt_sequence( def test_unstable_prompt_sequence(
vllm_runner, vllm_runner,
backend: str, backend: str,
monkeypatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
if backend == "FLASHINFER" and current_platform.is_rocm(): if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.") pytest.skip("Flashinfer does not support ROCm/HIP.")
if backend == "XFORMERS" and current_platform.is_rocm(): if backend == "XFORMERS" and current_platform.is_rocm():
pytest.skip("Xformers does not support ROCm/HIP.") pytest.skip("Xformers does not support ROCm/HIP.")
override_backend_env_variable(monkeypatch, backend) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, backend)
with vllm_runner( with vllm_runner(
"Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct",
......
...@@ -56,12 +56,11 @@ def test_gc(): ...@@ -56,12 +56,11 @@ def test_gc():
assert allocated < 50 * 1024 * 1024 assert allocated < 50 * 1024 * 1024
def test_model_from_modelscope(monkeypatch): def test_model_from_modelscope(monkeypatch: pytest.MonkeyPatch):
# model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary # model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary
MODELSCOPE_MODEL_NAME = "qwen/Qwen1.5-0.5B-Chat" with monkeypatch.context() as m:
monkeypatch.setenv("VLLM_USE_MODELSCOPE", "True") m.setenv("VLLM_USE_MODELSCOPE", "True")
try: llm = LLM(model="qwen/Qwen1.5-0.5B-Chat")
llm = LLM(model=MODELSCOPE_MODEL_NAME)
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
...@@ -73,10 +72,3 @@ def test_model_from_modelscope(monkeypatch): ...@@ -73,10 +72,3 @@ def test_model_from_modelscope(monkeypatch):
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
assert len(outputs) == 4 assert len(outputs) == 4
finally:
monkeypatch.delenv("VLLM_USE_MODELSCOPE", raising=False)
if __name__ == "__main__":
import pytest
pytest.main([__file__])
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# ruff: noqa
import asyncio import asyncio
import os
import socket import socket
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from unittest.mock import patch from unittest.mock import patch
...@@ -112,8 +112,9 @@ def test_deprecate_kwargs_additional_message(): ...@@ -112,8 +112,9 @@ def test_deprecate_kwargs_additional_message():
dummy(old_arg=1) dummy(old_arg=1)
def test_get_open_port(): def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
os.environ["VLLM_PORT"] = "5678" with monkeypatch.context() as m:
m.setenv("VLLM_PORT", "5678")
# make sure we can get multiple ports, even if the env var is set # make sure we can get multiple ports, even if the env var is set
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
s1.bind(("localhost", get_open_port())) s1.bind(("localhost", get_open_port()))
...@@ -121,7 +122,6 @@ def test_get_open_port(): ...@@ -121,7 +122,6 @@ def test_get_open_port():
s2.bind(("localhost", get_open_port())) s2.bind(("localhost", get_open_port()))
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
s3.bind(("localhost", get_open_port())) s3.bind(("localhost", get_open_port()))
os.environ.pop("VLLM_PORT")
# Tests for FlexibleArgumentParser # Tests for FlexibleArgumentParser
...@@ -366,9 +366,10 @@ def test_bind_kv_cache_non_attention(): ...@@ -366,9 +366,10 @@ def test_bind_kv_cache_non_attention():
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1] assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]
def test_bind_kv_cache_encoder_decoder(monkeypatch): def test_bind_kv_cache_encoder_decoder(monkeypatch: pytest.MonkeyPatch):
# V1 TESTS: ENCODER_DECODER is not supported on V1 yet. # V1 TESTS: ENCODER_DECODER is not supported on V1 yet.
monkeypatch.setenv("VLLM_USE_V1", "0") with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
from vllm.attention import Attention, AttentionType from vllm.attention import Attention, AttentionType
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import pytest
from vllm.config import CompilationLevel from vllm.config import CompilationLevel
...@@ -9,10 +9,11 @@ from ..utils import compare_two_settings ...@@ -9,10 +9,11 @@ from ..utils import compare_two_settings
# --enforce-eager on TPU causes graph compilation # --enforce-eager on TPU causes graph compilation
# this times out default Health Check in the MQLLMEngine, # this times out default Health Check in the MQLLMEngine,
# so we set the timeout here to 30s # so we set the timeout here to 30s
os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def test_custom_dispatcher(): def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_RPC_TIMEOUT", "30000")
compare_two_settings( compare_two_settings(
"google/gemma-2b", "google/gemma-2b",
arg1=[ arg1=[
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# ruff: noqa
# type: ignore
from __future__ import annotations
import os
import threading import threading
from collections.abc import Iterable from collections.abc import Iterable
from concurrent import futures from concurrent import futures
from typing import Callable, Literal from typing import Callable, Generator, Literal
import grpc import grpc
import pytest import pytest
...@@ -21,12 +23,14 @@ from vllm.tracing import SpanAttributes ...@@ -21,12 +23,14 @@ from vllm.tracing import SpanAttributes
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch): def use_v0_only(monkeypatch: pytest.MonkeyPatch):
""" """
Since this module is V0 only, set VLLM_USE_V1=0 for Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module. all tests in the module.
""" """
monkeypatch.setenv('VLLM_USE_V1', '0') with monkeypatch.context() as m:
m.setenv('VLLM_USE_V1', '0')
yield
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
...@@ -67,7 +71,7 @@ class FakeTraceService(TraceServiceServicer): ...@@ -67,7 +71,7 @@ class FakeTraceService(TraceServiceServicer):
@pytest.fixture @pytest.fixture
def trace_service(): def trace_service() -> Generator[FakeTraceService, None, None]:
"""Fixture to set up a fake gRPC trace service""" """Fixture to set up a fake gRPC trace service"""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
service = FakeTraceService() service = FakeTraceService()
...@@ -80,12 +84,18 @@ def trace_service(): ...@@ -80,12 +84,18 @@ def trace_service():
server.stop(None) server.stop(None)
def test_traces(trace_service): def test_traces(
os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true" monkeypatch: pytest.MonkeyPatch,
trace_service: FakeTraceService,
):
with monkeypatch.context() as m:
m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
sampling_params = SamplingParams(temperature=0.01, sampling_params = SamplingParams(
temperature=0.01,
top_p=0.1, top_p=0.1,
max_tokens=256) max_tokens=256,
)
model = "facebook/opt-125m" model = "facebook/opt-125m"
llm = LLM( llm = LLM(
model=model, model=model,
...@@ -120,17 +130,19 @@ def test_traces(trace_service): ...@@ -120,17 +130,19 @@ def test_traces(trace_service):
) == sampling_params.temperature ) == sampling_params.temperature
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
) == sampling_params.max_tokens
assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS) == sampling_params.max_tokens SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
outputs[0].prompt_token_ids) outputs[0].prompt_token_ids)
completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
metrics = outputs[0].metrics metrics = outputs[0].metrics
assert attributes.get( assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE
SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue ) == metrics.time_in_queue
ttft = metrics.first_token_time - metrics.arrival_time ttft = metrics.first_token_time - metrics.arrival_time
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
...@@ -145,12 +157,18 @@ def test_traces(trace_service): ...@@ -145,12 +157,18 @@ def test_traces(trace_service):
assert metrics.model_execute_time is None assert metrics.model_execute_time is None
def test_traces_with_detailed_steps(trace_service): def test_traces_with_detailed_steps(
os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true" monkeypatch: pytest.MonkeyPatch,
trace_service: FakeTraceService,
):
with monkeypatch.context() as m:
m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
sampling_params = SamplingParams(temperature=0.01, sampling_params = SamplingParams(
temperature=0.01,
top_p=0.1, top_p=0.1,
max_tokens=256) max_tokens=256,
)
model = "facebook/opt-125m" model = "facebook/opt-125m"
llm = LLM( llm = LLM(
model=model, model=model,
...@@ -186,17 +204,19 @@ def test_traces_with_detailed_steps(trace_service): ...@@ -186,17 +204,19 @@ def test_traces_with_detailed_steps(trace_service):
) == sampling_params.temperature ) == sampling_params.temperature
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
) == sampling_params.max_tokens
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS) == sampling_params.max_tokens SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n assert attributes.get(
assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
outputs[0].prompt_token_ids) outputs[0].prompt_token_ids)
completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
metrics = outputs[0].metrics metrics = outputs[0].metrics
assert attributes.get( assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE
SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue ) == metrics.time_in_queue
ttft = metrics.first_token_time - metrics.arrival_time ttft = metrics.first_token_time - metrics.arrival_time
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
...@@ -207,9 +227,10 @@ def test_traces_with_detailed_steps(trace_service): ...@@ -207,9 +227,10 @@ def test_traces_with_detailed_steps(trace_service):
) == metrics.scheduler_time ) == metrics.scheduler_time
assert metrics.model_forward_time > 0 assert metrics.model_forward_time > 0
assert attributes.get( assert attributes.get(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD) == pytest.approx( SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD
metrics.model_forward_time / 1000) ) == pytest.approx(metrics.model_forward_time / 1000)
assert metrics.model_execute_time > 0 assert metrics.model_execute_time > 0
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE assert attributes.get(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE
) == metrics.model_execute_time ) == metrics.model_execute_time
assert metrics.model_forward_time < 1000 * metrics.model_execute_time assert metrics.model_forward_time < 1000 * metrics.model_execute_time
...@@ -566,6 +566,7 @@ def init_test_distributed_environment( ...@@ -566,6 +566,7 @@ def init_test_distributed_environment(
def multi_process_parallel( def multi_process_parallel(
monkeypatch: pytest.MonkeyPatch,
tp_size: int, tp_size: int,
pp_size: int, pp_size: int,
test_target: Any, test_target: Any,
...@@ -582,7 +583,13 @@ def multi_process_parallel( ...@@ -582,7 +583,13 @@ def multi_process_parallel(
refs = [] refs = []
for rank in range(tp_size * pp_size): for rank in range(tp_size * pp_size):
refs.append( refs.append(
test_target.remote(tp_size, pp_size, rank, distributed_init_port)) test_target.remote(
monkeypatch,
tp_size,
pp_size,
rank,
distributed_init_port,
), )
ray.get(refs) ray.get(refs)
ray.shutdown() ray.shutdown()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import random import random
from typing import Any
import pytest import pytest
...@@ -50,8 +53,12 @@ def model_name(): ...@@ -50,8 +53,12 @@ def model_name():
return "meta-llama/Meta-Llama-3-8B-Instruct" return "meta-llama/Meta-Llama-3-8B-Instruct"
def test_ngram_correctness(monkeypatch, test_prompts, sampling_config, def test_ngram_correctness(
model_name): monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
''' '''
Compare the outputs of a original LLM and a speculative LLM Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding. should be the same when using ngram speculative decoding.
......
...@@ -80,9 +80,11 @@ async def generate(engine: AsyncLLM, ...@@ -80,9 +80,11 @@ async def generate(engine: AsyncLLM,
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)]) (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load(monkeypatch, output_kind: RequestOutputKind, async def test_load(
engine_args_and_prompt: tuple[AsyncEngineArgs, monkeypatch: pytest.MonkeyPatch,
PromptType]): output_kind: RequestOutputKind,
engine_args_and_prompt: tuple[AsyncEngineArgs, PromptType],
):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1 # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# so that in the future when we switch, we don't have to change all the # so that in the future when we switch, we don't have to change all the
# tests. # tests.
...@@ -126,7 +128,8 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind, ...@@ -126,7 +128,8 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind,
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)]) (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_abort(monkeypatch, output_kind: RequestOutputKind, async def test_abort(monkeypatch: pytest.MonkeyPatch,
output_kind: RequestOutputKind,
engine_args_and_prompt: tuple[AsyncEngineArgs, engine_args_and_prompt: tuple[AsyncEngineArgs,
PromptType]): PromptType]):
......
...@@ -45,7 +45,7 @@ def make_request() -> EngineCoreRequest: ...@@ -45,7 +45,7 @@ def make_request() -> EngineCoreRequest:
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_engine_core(monkeypatch): def test_engine_core(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
...@@ -159,7 +159,7 @@ def test_engine_core(monkeypatch): ...@@ -159,7 +159,7 @@ def test_engine_core(monkeypatch):
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_engine_core_advanced_sampling(monkeypatch): def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
""" """
A basic end-to-end test to verify that the engine functions correctly A basic end-to-end test to verify that the engine functions correctly
when additional sampling parameters, such as top_p, min_tokens, and when additional sampling parameters, such as top_p, min_tokens, and
...@@ -209,7 +209,7 @@ def test_engine_core_advanced_sampling(monkeypatch): ...@@ -209,7 +209,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
@fork_new_process_for_each_test @fork_new_process_for_each_test
def test_engine_core_concurrent_batches(monkeypatch): def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
""" """
Test that the engine can handle multiple concurrent batches. Test that the engine can handle multiple concurrent batches.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment