Unverified Commit a73e183e authored by Sibi's avatar Sibi Committed by GitHub
Browse files

[Misc] Replace os environ to monkeypatch in test suite (#14516)


Signed-off-by: default avatarsibi <85477603+t-sibiraj@users.noreply.github.com>
Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatarAaron Pham <contact@aarnphm.xyz>
parent 1e799b7e
...@@ -90,7 +90,8 @@ def echo(self, msg: str, err_msg: Optional[str] = None) -> str: ...@@ -90,7 +90,8 @@ def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
@fork_new_process_for_each_test @fork_new_process_for_each_test
@pytest.mark.parametrize("multiprocessing_mode", [True, False]) @pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
multiprocessing_mode: bool):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
...@@ -175,7 +176,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): ...@@ -175,7 +176,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
@pytest.mark.asyncio(loop_scope="function") @pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_asyncio(monkeypatch): async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
......
...@@ -57,7 +57,7 @@ def _repeat_logprob_config( ...@@ -57,7 +57,7 @@ def _repeat_logprob_config(
logprob_prompt_logprob_list: BatchLogprobsSpecType, logprob_prompt_logprob_list: BatchLogprobsSpecType,
) -> BatchLogprobsSpecType: ) -> BatchLogprobsSpecType:
"""Ensure each test prompt has a logprob config. """Ensure each test prompt has a logprob config.
A logprob config specifies the optional (i.e. A logprob config specifies the optional (i.e.
may-be-`None`) number of sample logprobs and may-be-`None`) number of sample logprobs and
the optional number of prompt logprobs. the optional number of prompt logprobs.
...@@ -80,7 +80,7 @@ def _repeat_logprob_config( ...@@ -80,7 +80,7 @@ def _repeat_logprob_config(
(optional num sample logprob, (optional num sample logprob,
optional num prompt logprob) optional num prompt logprob)
tuples tuples
Returns: Returns:
list of list of
(optional num sample logprob,optional num prompt logprob) (optional num sample logprob,optional num prompt logprob)
...@@ -255,14 +255,12 @@ def _run_and_validate( ...@@ -255,14 +255,12 @@ def _run_and_validate(
[NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]) [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
@pytest.mark.parametrize("temperature", [0.0, 2.0]) @pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs( def test_get_logprobs_and_prompt_logprobs(
hf_model, hf_model, vllm_model,
vllm_model, batch_logprobs_composition: BatchLogprobsComposition,
batch_logprobs_composition: BatchLogprobsComposition, temperature: float, example_prompts: list[str],
temperature: float, monkeypatch: pytest.MonkeyPatch) -> None:
example_prompts,
) -> None:
"""Test V1 Engine logprobs & prompt logprobs """Test V1 Engine logprobs & prompt logprobs
Exercise a variety of combinations of `logprobs` and `prompt_logprobs` Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
settings and validate that settings and validate that
* The generated logprobs and prompt logprobs are consistent with the * The generated logprobs and prompt logprobs are consistent with the
...@@ -279,7 +277,7 @@ def test_get_logprobs_and_prompt_logprobs( ...@@ -279,7 +277,7 @@ def test_get_logprobs_and_prompt_logprobs(
To save time, only test one APC-enabled scenario To save time, only test one APC-enabled scenario
(sample & prompt logprobs enabled, temperature>0.0). (sample & prompt logprobs enabled, temperature>0.0).
Args: Args:
hf_model: HuggingFace reference model fixture hf_model: HuggingFace reference model fixture
vllm_model: vLLM model fixture vllm_model: vLLM model fixture
...@@ -287,128 +285,140 @@ def test_get_logprobs_and_prompt_logprobs( ...@@ -287,128 +285,140 @@ def test_get_logprobs_and_prompt_logprobs(
temperature: "temperature" sampling parameter temperature: "temperature" sampling parameter
example_prompts: example prompt fixture example_prompts: example prompt fixture
""" """
do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching with monkeypatch.context() as m:
if do_apc and (temperature < 2.0 m.setenv("VLLM_USE_V1", "1")
or batch_logprobs_composition != SAMPLE_PROMPT): do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching
# Skip some test-cases to save time. if do_apc and (temperature < 2.0
pytest.skip() or batch_logprobs_composition != SAMPLE_PROMPT):
test_prompts = example_prompts # Skip some test-cases to save time.
pytest.skip()
max_tokens = 5 test_prompts = example_prompts
hf_outputs = hf_model.generate_greedy(
test_prompts, max_tokens = 5
max_tokens=max_tokens, hf_outputs = hf_model.generate_greedy(
) test_prompts,
hf_logprobs = hf_model.generate_greedy_logprobs(
test_prompts,
max_tokens=max_tokens,
)
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list = _repeat_logprob_config(
test_prompts, logprob_prompt_logprob_list)
# Generate SamplingParams
vllm_sampling_params = [
SamplingParams(max_tokens=max_tokens,
logprobs=num_lp,
prompt_logprobs=num_plp,
temperature=temperature,
seed=1984)
for num_lp, num_plp in logprob_prompt_logprob_list
]
for _ in range(2 if do_apc else 1):
_run_and_validate(
vllm_model=vllm_model,
test_prompts=test_prompts,
vllm_sampling_params=vllm_sampling_params,
hf_logprobs=hf_logprobs,
hf_outputs=hf_outputs,
logprob_prompt_logprob_list=logprob_prompt_logprob_list,
temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
do_apc=do_apc) )
hf_logprobs = hf_model.generate_greedy_logprobs(
test_prompts,
def test_max_logprobs(): max_tokens=max_tokens,
)
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list = get_test_batch(
batch_logprobs_composition)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list = _repeat_logprob_config(
test_prompts, logprob_prompt_logprob_list)
# Generate SamplingParams
vllm_sampling_params = [
SamplingParams(max_tokens=max_tokens,
logprobs=num_lp,
prompt_logprobs=num_plp,
temperature=temperature,
seed=1984)
for num_lp, num_plp in logprob_prompt_logprob_list
]
for _ in range(2 if do_apc else 1):
_run_and_validate(
vllm_model=vllm_model,
test_prompts=test_prompts,
vllm_sampling_params=vllm_sampling_params,
hf_logprobs=hf_logprobs,
hf_outputs=hf_outputs,
logprob_prompt_logprob_list=logprob_prompt_logprob_list,
temperature=temperature,
max_tokens=max_tokens,
do_apc=do_apc)
def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
"""vLLM v1 engine should fail a request with `logprobs > max_logprobs` """vLLM v1 engine should fail a request with `logprobs > max_logprobs`
Should also fail for `prompt_logprobs > max_logprobs` Should also fail for `prompt_logprobs > max_logprobs`
APC should not matter as this test checks basic request validation. APC should not matter as this test checks basic request validation.
Args:
monkeypatch
""" """
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
runner = VllmRunner("facebook/opt-125m", runner = VllmRunner("facebook/opt-125m",
max_logprobs=1, max_logprobs=1,
enable_prefix_caching=False, enable_prefix_caching=False,
max_model_len=256) max_model_len=256)
vllm_sampling_params = SamplingParams(logprobs=1) vllm_sampling_params = SamplingParams(logprobs=1)
# should pass # should pass
runner.generate(["Hello world"], sampling_params=vllm_sampling_params) runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
bad_sampling_params = SamplingParams(logprobs=2) bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError): with pytest.raises(ValueError):
runner.generate(["Hello world"], sampling_params=bad_sampling_params) runner.generate(["Hello world"],
sampling_params=bad_sampling_params)
def test_none_logprobs(vllm_model, example_prompts): def test_none_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
"""Engine should return `logprobs` and `prompt_logprobs` as `None` """Engine should return `logprobs` and `prompt_logprobs` as `None`
Args: Args:
vllm_model: vLLM model fixture vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture) example_prompts: list of example prompts (test fixture)
""" """
max_tokens = 5 with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens, max_tokens = 5
logprobs=None,
prompt_logprobs=None,
temperature=0.0)
results_logprobs_none = vllm_model.model.generate(
example_prompts, sampling_params=sampling_params_logprobs_none)
for i in range(len(results_logprobs_none)): sampling_params_logprobs_none = SamplingParams(
# Check sample logprobs are None max_tokens=max_tokens,
assert results_logprobs_none[i].outputs[0].logprobs is None logprobs=None,
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None prompt_logprobs=None,
# Check prompt logprobs are None temperature=0.0,
assert results_logprobs_none[i].prompt_logprobs is None )
results_logprobs_none = vllm_model.model.generate(
example_prompts,
def test_zero_logprobs(vllm_model, example_prompts): sampling_params=sampling_params_logprobs_none,
)
for i in range(len(results_logprobs_none)):
# Check sample logprobs are None
assert results_logprobs_none[i].outputs[0].logprobs is None
assert results_logprobs_none[i].outputs[
0].cumulative_logprob is None
# Check prompt logprobs are None
assert results_logprobs_none[i].prompt_logprobs is None
def test_zero_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
"""Engine should return sampled token and prompt token logprobs """Engine should return sampled token and prompt token logprobs
Args: Args:
vllm_model: vLLM model fixture vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture) example_prompts: list of example prompts (test fixture)
""" """
max_tokens = 5 with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens, max_tokens = 5
logprobs=0,
prompt_logprobs=0, sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
temperature=0.0) logprobs=0,
results_logprobs_zero = vllm_model.model.generate( prompt_logprobs=0,
example_prompts, sampling_params=sampling_params_logprobs_zero) temperature=0.0)
results_logprobs_zero = vllm_model.model.generate(
for i in range(len(results_logprobs_zero)): example_prompts, sampling_params=sampling_params_logprobs_zero)
# Check that there is one sample logprob dict for each
# sample token for i in range(len(results_logprobs_zero)):
logprobs = results_logprobs_zero[i].outputs[0].logprobs # Check that there is one sample logprob dict for each
prompt_logprobs = results_logprobs_zero[i].prompt_logprobs # sample token
sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids logprobs = results_logprobs_zero[i].outputs[0].logprobs
prompt_token_ids = results_logprobs_zero[i].prompt_token_ids prompt_logprobs = results_logprobs_zero[i].prompt_logprobs
assert logprobs is not None sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids
assert len(sampled_token_ids) == len(logprobs) prompt_token_ids = results_logprobs_zero[i].prompt_token_ids
assert results_logprobs_zero[i].outputs[ assert logprobs is not None
0].cumulative_logprob is not None assert len(sampled_token_ids) == len(logprobs)
# Check that there is one prompt logprob dict for each assert results_logprobs_zero[i].outputs[
# prompt token 0].cumulative_logprob is not None
assert prompt_logprobs is not None # Check that there is one prompt logprob dict for each
assert len(prompt_token_ids) == len(prompt_logprobs) # prompt token
assert prompt_logprobs is not None
assert len(prompt_token_ids) == len(prompt_logprobs)
...@@ -3,11 +3,16 @@ ...@@ -3,11 +3,16 @@
Run `pytest tests/v1/tpu/test_basic.py`. Run `pytest tests/v1/tpu/test_basic.py`.
""" """
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest import pytest
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ...conftest import VllmRunner if TYPE_CHECKING:
from tests.conftest import VllmRunner
MODELS = [ MODELS = [
# "Qwen/Qwen2-7B-Instruct", # "Qwen/Qwen2-7B-Instruct",
...@@ -28,7 +33,8 @@ TENSOR_PARALLEL_SIZES = [1] ...@@ -28,7 +33,8 @@ TENSOR_PARALLEL_SIZES = [1]
@pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES) @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
def test_models( def test_models(
monkeypatch, vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
model: str, model: str,
max_tokens: int, max_tokens: int,
enforce_eager: bool, enforce_eager: bool,
...@@ -41,7 +47,7 @@ def test_models( ...@@ -41,7 +47,7 @@ def test_models(
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
with VllmRunner( with vllm_runner(
model, model,
max_model_len=8192, max_model_len=8192,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
...@@ -50,5 +56,5 @@ def test_models( ...@@ -50,5 +56,5 @@ def test_models(
tensor_parallel_size=tensor_parallel_size) as vllm_model: tensor_parallel_size=tensor_parallel_size) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens) max_tokens)
output = vllm_outputs[0][1] output = vllm_outputs[0][1]
assert "1024" in output assert "1024" in output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment