Unverified commit a73e183e, authored by Sibi, committed by GitHub
Browse files

[Misc] Replace os environ to monkeypatch in test suite (#14516)


Signed-off-by: sibi <85477603+t-sibiraj@users.noreply.github.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
parent 1e799b7e
......@@ -90,7 +90,8 @@ def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
@fork_new_process_for_each_test
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
multiprocessing_mode: bool):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
......@@ -175,7 +176,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_asyncio(monkeypatch):
async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
......
......@@ -255,12 +255,10 @@ def _run_and_validate(
[NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
hf_model,
vllm_model,
hf_model, vllm_model,
batch_logprobs_composition: BatchLogprobsComposition,
temperature: float,
example_prompts,
) -> None:
temperature: float, example_prompts: list[str],
monkeypatch: pytest.MonkeyPatch) -> None:
"""Test V1 Engine logprobs & prompt logprobs
Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
......@@ -287,6 +285,8 @@ def test_get_logprobs_and_prompt_logprobs(
temperature: "temperature" sampling parameter
example_prompts: example prompt fixture
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching
if do_apc and (temperature < 2.0
or batch_logprobs_composition != SAMPLE_PROMPT):
......@@ -306,7 +306,8 @@ def test_get_logprobs_and_prompt_logprobs(
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
logprob_prompt_logprob_list = get_test_batch(
batch_logprobs_composition)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list = _repeat_logprob_config(
......@@ -333,16 +334,13 @@ def test_get_logprobs_and_prompt_logprobs(
do_apc=do_apc)
def test_max_logprobs():
def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
"""vLLM v1 engine should fail a request with `logprobs > max_logprobs`
Should also fail for `prompt_logprobs > max_logprobs`
APC should not matter as this test checks basic request validation.
Args:
monkeypatch
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
runner = VllmRunner("facebook/opt-125m",
max_logprobs=1,
......@@ -354,40 +352,52 @@ def test_max_logprobs():
bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError):
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
runner.generate(["Hello world"],
sampling_params=bad_sampling_params)
def test_none_logprobs(vllm_model, example_prompts):
def test_none_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
"""Engine should return `logprobs` and `prompt_logprobs` as `None`
Args:
vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture)
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
max_tokens = 5
sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
sampling_params_logprobs_none = SamplingParams(
max_tokens=max_tokens,
logprobs=None,
prompt_logprobs=None,
temperature=0.0)
temperature=0.0,
)
results_logprobs_none = vllm_model.model.generate(
example_prompts, sampling_params=sampling_params_logprobs_none)
example_prompts,
sampling_params=sampling_params_logprobs_none,
)
for i in range(len(results_logprobs_none)):
# Check sample logprobs are None
assert results_logprobs_none[i].outputs[0].logprobs is None
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
assert results_logprobs_none[i].outputs[
0].cumulative_logprob is None
# Check prompt logprobs are None
assert results_logprobs_none[i].prompt_logprobs is None
def test_zero_logprobs(vllm_model, example_prompts):
def test_zero_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
"""Engine should return sampled token and prompt token logprobs
Args:
vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture)
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
max_tokens = 5
sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
......
......@@ -3,11 +3,16 @@
Run `pytest tests/v1/tpu/test_basic.py`.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from vllm.platforms import current_platform
from ...conftest import VllmRunner
if TYPE_CHECKING:
from tests.conftest import VllmRunner
MODELS = [
# "Qwen/Qwen2-7B-Instruct",
......@@ -28,7 +33,8 @@ TENSOR_PARALLEL_SIZES = [1]
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
def test_models(
monkeypatch,
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
model: str,
max_tokens: int,
enforce_eager: bool,
......@@ -41,7 +47,7 @@ def test_models(
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
with VllmRunner(
with vllm_runner(
model,
max_model_len=8192,
enforce_eager=enforce_eager,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment