Unverified Commit a73e183e authored by Sibi's avatar Sibi Committed by GitHub
Browse files

[Misc] Replace os environ to monkeypatch in test suite (#14516)


Signed-off-by: default avatarsibi <85477603+t-sibiraj@users.noreply.github.com>
Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatarAaron Pham <contact@aarnphm.xyz>
parent 1e799b7e
...@@ -90,7 +90,8 @@ def echo(self, msg: str, err_msg: Optional[str] = None) -> str: ...@@ -90,7 +90,8 @@ def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
@fork_new_process_for_each_test @fork_new_process_for_each_test
@pytest.mark.parametrize("multiprocessing_mode", [True, False]) @pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
multiprocessing_mode: bool):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
...@@ -175,7 +176,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): ...@@ -175,7 +176,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
@pytest.mark.asyncio(loop_scope="function") @pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_asyncio(monkeypatch): async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
......
...@@ -255,12 +255,10 @@ def _run_and_validate( ...@@ -255,12 +255,10 @@ def _run_and_validate(
[NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]) [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
@pytest.mark.parametrize("temperature", [0.0, 2.0]) @pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs( def test_get_logprobs_and_prompt_logprobs(
hf_model, hf_model, vllm_model,
vllm_model,
batch_logprobs_composition: BatchLogprobsComposition, batch_logprobs_composition: BatchLogprobsComposition,
temperature: float, temperature: float, example_prompts: list[str],
example_prompts, monkeypatch: pytest.MonkeyPatch) -> None:
) -> None:
"""Test V1 Engine logprobs & prompt logprobs """Test V1 Engine logprobs & prompt logprobs
Exercise a variety of combinations of `logprobs` and `prompt_logprobs` Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
...@@ -287,6 +285,8 @@ def test_get_logprobs_and_prompt_logprobs( ...@@ -287,6 +285,8 @@ def test_get_logprobs_and_prompt_logprobs(
temperature: "temperature" sampling parameter temperature: "temperature" sampling parameter
example_prompts: example prompt fixture example_prompts: example prompt fixture
""" """
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching
if do_apc and (temperature < 2.0 if do_apc and (temperature < 2.0
or batch_logprobs_composition != SAMPLE_PROMPT): or batch_logprobs_composition != SAMPLE_PROMPT):
...@@ -306,7 +306,8 @@ def test_get_logprobs_and_prompt_logprobs( ...@@ -306,7 +306,8 @@ def test_get_logprobs_and_prompt_logprobs(
# Batch has mixed sample params # Batch has mixed sample params
# (different logprobs/prompt logprobs combos) # (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition) logprob_prompt_logprob_list = get_test_batch(
batch_logprobs_composition)
# Ensure that each test prompt has a logprob config for testing # Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list = _repeat_logprob_config( logprob_prompt_logprob_list = _repeat_logprob_config(
...@@ -333,16 +334,13 @@ def test_get_logprobs_and_prompt_logprobs( ...@@ -333,16 +334,13 @@ def test_get_logprobs_and_prompt_logprobs(
do_apc=do_apc) do_apc=do_apc)
def test_max_logprobs(): def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
"""vLLM v1 engine should fail a request with `logprobs > max_logprobs` """vLLM v1 engine should fail a request with `logprobs > max_logprobs`
Should also fail for `prompt_logprobs > max_logprobs` Should also fail for `prompt_logprobs > max_logprobs`
APC should not matter as this test checks basic request validation. APC should not matter as this test checks basic request validation.
Args:
monkeypatch
""" """
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
runner = VllmRunner("facebook/opt-125m", runner = VllmRunner("facebook/opt-125m",
max_logprobs=1, max_logprobs=1,
...@@ -354,40 +352,52 @@ def test_max_logprobs(): ...@@ -354,40 +352,52 @@ def test_max_logprobs():
bad_sampling_params = SamplingParams(logprobs=2) bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError): with pytest.raises(ValueError):
runner.generate(["Hello world"], sampling_params=bad_sampling_params) runner.generate(["Hello world"],
sampling_params=bad_sampling_params)
def test_none_logprobs(vllm_model, example_prompts): def test_none_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
"""Engine should return `logprobs` and `prompt_logprobs` as `None` """Engine should return `logprobs` and `prompt_logprobs` as `None`
Args: Args:
vllm_model: vLLM model fixture vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture) example_prompts: list of example prompts (test fixture)
""" """
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
max_tokens = 5 max_tokens = 5
sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens, sampling_params_logprobs_none = SamplingParams(
max_tokens=max_tokens,
logprobs=None, logprobs=None,
prompt_logprobs=None, prompt_logprobs=None,
temperature=0.0) temperature=0.0,
)
results_logprobs_none = vllm_model.model.generate( results_logprobs_none = vllm_model.model.generate(
example_prompts, sampling_params=sampling_params_logprobs_none) example_prompts,
sampling_params=sampling_params_logprobs_none,
)
for i in range(len(results_logprobs_none)): for i in range(len(results_logprobs_none)):
# Check sample logprobs are None # Check sample logprobs are None
assert results_logprobs_none[i].outputs[0].logprobs is None assert results_logprobs_none[i].outputs[0].logprobs is None
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None assert results_logprobs_none[i].outputs[
0].cumulative_logprob is None
# Check prompt logprobs are None # Check prompt logprobs are None
assert results_logprobs_none[i].prompt_logprobs is None assert results_logprobs_none[i].prompt_logprobs is None
def test_zero_logprobs(vllm_model, example_prompts): def test_zero_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
"""Engine should return sampled token and prompt token logprobs """Engine should return sampled token and prompt token logprobs
Args: Args:
vllm_model: vLLM model fixture vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture) example_prompts: list of example prompts (test fixture)
""" """
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
max_tokens = 5 max_tokens = 5
sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens, sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
......
...@@ -3,11 +3,16 @@ ...@@ -3,11 +3,16 @@
Run `pytest tests/v1/tpu/test_basic.py`. Run `pytest tests/v1/tpu/test_basic.py`.
""" """
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest import pytest
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ...conftest import VllmRunner if TYPE_CHECKING:
from tests.conftest import VllmRunner
MODELS = [ MODELS = [
# "Qwen/Qwen2-7B-Instruct", # "Qwen/Qwen2-7B-Instruct",
...@@ -28,7 +33,8 @@ TENSOR_PARALLEL_SIZES = [1] ...@@ -28,7 +33,8 @@ TENSOR_PARALLEL_SIZES = [1]
@pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES) @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
def test_models( def test_models(
monkeypatch, vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
model: str, model: str,
max_tokens: int, max_tokens: int,
enforce_eager: bool, enforce_eager: bool,
...@@ -41,7 +47,7 @@ def test_models( ...@@ -41,7 +47,7 @@ def test_models(
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
with VllmRunner( with vllm_runner(
model, model,
max_model_len=8192, max_model_len=8192,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment