Unverified Commit a73e183e authored by Sibi's avatar Sibi Committed by GitHub
Browse files

[Misc] Replace os environ to monkeypatch in test suite (#14516)


Signed-off-by: default avatarsibi <85477603+t-sibiraj@users.noreply.github.com>
Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatarAaron Pham <contact@aarnphm.xyz>
parent 1e799b7e
......@@ -90,7 +90,8 @@ def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
@fork_new_process_for_each_test
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
multiprocessing_mode: bool):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
......@@ -175,7 +176,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_asyncio(monkeypatch):
async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
......
......@@ -57,7 +57,7 @@ def _repeat_logprob_config(
logprob_prompt_logprob_list: BatchLogprobsSpecType,
) -> BatchLogprobsSpecType:
"""Ensure each test prompt has a logprob config.
A logprob config specifies the optional (i.e.
may-be-`None`) number of sample logprobs and
the optional number of prompt logprobs.
......@@ -80,7 +80,7 @@ def _repeat_logprob_config(
(optional num sample logprob,
optional num prompt logprob)
tuples
Returns:
list of
(optional num sample logprob,optional num prompt logprob)
......@@ -255,14 +255,12 @@ def _run_and_validate(
[NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
hf_model,
vllm_model,
batch_logprobs_composition: BatchLogprobsComposition,
temperature: float,
example_prompts,
) -> None:
hf_model, vllm_model,
batch_logprobs_composition: BatchLogprobsComposition,
temperature: float, example_prompts: list[str],
monkeypatch: pytest.MonkeyPatch) -> None:
"""Test V1 Engine logprobs & prompt logprobs
Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
settings and validate that
* The generated logprobs and prompt logprobs are consistent with the
......@@ -279,7 +277,7 @@ def test_get_logprobs_and_prompt_logprobs(
To save time, only test one APC-enabled scenario
(sample & prompt logprobs enabled, temperature>0.0).
Args:
hf_model: HuggingFace reference model fixture
vllm_model: vLLM model fixture
......@@ -287,128 +285,140 @@ def test_get_logprobs_and_prompt_logprobs(
temperature: "temperature" sampling parameter
example_prompts: example prompt fixture
"""
do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching
if do_apc and (temperature < 2.0
or batch_logprobs_composition != SAMPLE_PROMPT):
# Skip some test-cases to save time.
pytest.skip()
test_prompts = example_prompts
max_tokens = 5
hf_outputs = hf_model.generate_greedy(
test_prompts,
max_tokens=max_tokens,
)
hf_logprobs = hf_model.generate_greedy_logprobs(
test_prompts,
max_tokens=max_tokens,
)
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list = _repeat_logprob_config(
test_prompts, logprob_prompt_logprob_list)
# Generate SamplingParams
vllm_sampling_params = [
SamplingParams(max_tokens=max_tokens,
logprobs=num_lp,
prompt_logprobs=num_plp,
temperature=temperature,
seed=1984)
for num_lp, num_plp in logprob_prompt_logprob_list
]
for _ in range(2 if do_apc else 1):
_run_and_validate(
vllm_model=vllm_model,
test_prompts=test_prompts,
vllm_sampling_params=vllm_sampling_params,
hf_logprobs=hf_logprobs,
hf_outputs=hf_outputs,
logprob_prompt_logprob_list=logprob_prompt_logprob_list,
temperature=temperature,
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching
if do_apc and (temperature < 2.0
or batch_logprobs_composition != SAMPLE_PROMPT):
# Skip some test-cases to save time.
pytest.skip()
test_prompts = example_prompts
max_tokens = 5
hf_outputs = hf_model.generate_greedy(
test_prompts,
max_tokens=max_tokens,
do_apc=do_apc)
def test_max_logprobs():
)
hf_logprobs = hf_model.generate_greedy_logprobs(
test_prompts,
max_tokens=max_tokens,
)
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list = get_test_batch(
batch_logprobs_composition)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list = _repeat_logprob_config(
test_prompts, logprob_prompt_logprob_list)
# Generate SamplingParams
vllm_sampling_params = [
SamplingParams(max_tokens=max_tokens,
logprobs=num_lp,
prompt_logprobs=num_plp,
temperature=temperature,
seed=1984)
for num_lp, num_plp in logprob_prompt_logprob_list
]
for _ in range(2 if do_apc else 1):
_run_and_validate(
vllm_model=vllm_model,
test_prompts=test_prompts,
vllm_sampling_params=vllm_sampling_params,
hf_logprobs=hf_logprobs,
hf_outputs=hf_outputs,
logprob_prompt_logprob_list=logprob_prompt_logprob_list,
temperature=temperature,
max_tokens=max_tokens,
do_apc=do_apc)
def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
"""vLLM v1 engine should fail a request with `logprobs > max_logprobs`
Should also fail for `prompt_logprobs > max_logprobs`
APC should not matter as this test checks basic request validation.
Args:
monkeypatch
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
runner = VllmRunner("facebook/opt-125m",
max_logprobs=1,
enable_prefix_caching=False,
max_model_len=256)
vllm_sampling_params = SamplingParams(logprobs=1)
# should pass
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
runner = VllmRunner("facebook/opt-125m",
max_logprobs=1,
enable_prefix_caching=False,
max_model_len=256)
vllm_sampling_params = SamplingParams(logprobs=1)
# should pass
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError):
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError):
runner.generate(["Hello world"],
sampling_params=bad_sampling_params)
def test_none_logprobs(vllm_model, example_prompts):
def test_none_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
"""Engine should return `logprobs` and `prompt_logprobs` as `None`
Args:
vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture)
"""
max_tokens = 5
sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
logprobs=None,
prompt_logprobs=None,
temperature=0.0)
results_logprobs_none = vllm_model.model.generate(
example_prompts, sampling_params=sampling_params_logprobs_none)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
max_tokens = 5
for i in range(len(results_logprobs_none)):
# Check sample logprobs are None
assert results_logprobs_none[i].outputs[0].logprobs is None
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
# Check prompt logprobs are None
assert results_logprobs_none[i].prompt_logprobs is None
def test_zero_logprobs(vllm_model, example_prompts):
sampling_params_logprobs_none = SamplingParams(
max_tokens=max_tokens,
logprobs=None,
prompt_logprobs=None,
temperature=0.0,
)
results_logprobs_none = vllm_model.model.generate(
example_prompts,
sampling_params=sampling_params_logprobs_none,
)
for i in range(len(results_logprobs_none)):
# Check sample logprobs are None
assert results_logprobs_none[i].outputs[0].logprobs is None
assert results_logprobs_none[i].outputs[
0].cumulative_logprob is None
# Check prompt logprobs are None
assert results_logprobs_none[i].prompt_logprobs is None
def test_zero_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
"""Engine should return sampled token and prompt token logprobs
Args:
vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture)
"""
max_tokens = 5
sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
logprobs=0,
prompt_logprobs=0,
temperature=0.0)
results_logprobs_zero = vllm_model.model.generate(
example_prompts, sampling_params=sampling_params_logprobs_zero)
for i in range(len(results_logprobs_zero)):
# Check that there is one sample logprob dict for each
# sample token
logprobs = results_logprobs_zero[i].outputs[0].logprobs
prompt_logprobs = results_logprobs_zero[i].prompt_logprobs
sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids
prompt_token_ids = results_logprobs_zero[i].prompt_token_ids
assert logprobs is not None
assert len(sampled_token_ids) == len(logprobs)
assert results_logprobs_zero[i].outputs[
0].cumulative_logprob is not None
# Check that there is one prompt logprob dict for each
# prompt token
assert prompt_logprobs is not None
assert len(prompt_token_ids) == len(prompt_logprobs)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
max_tokens = 5
sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
logprobs=0,
prompt_logprobs=0,
temperature=0.0)
results_logprobs_zero = vllm_model.model.generate(
example_prompts, sampling_params=sampling_params_logprobs_zero)
for i in range(len(results_logprobs_zero)):
# Check that there is one sample logprob dict for each
# sample token
logprobs = results_logprobs_zero[i].outputs[0].logprobs
prompt_logprobs = results_logprobs_zero[i].prompt_logprobs
sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids
prompt_token_ids = results_logprobs_zero[i].prompt_token_ids
assert logprobs is not None
assert len(sampled_token_ids) == len(logprobs)
assert results_logprobs_zero[i].outputs[
0].cumulative_logprob is not None
# Check that there is one prompt logprob dict for each
# prompt token
assert prompt_logprobs is not None
assert len(prompt_token_ids) == len(prompt_logprobs)
......@@ -3,11 +3,16 @@
Run `pytest tests/v1/tpu/test_basic.py`.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from vllm.platforms import current_platform
from ...conftest import VllmRunner
if TYPE_CHECKING:
from tests.conftest import VllmRunner
MODELS = [
# "Qwen/Qwen2-7B-Instruct",
......@@ -28,7 +33,8 @@ TENSOR_PARALLEL_SIZES = [1]
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
def test_models(
monkeypatch,
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
model: str,
max_tokens: int,
enforce_eager: bool,
......@@ -41,7 +47,7 @@ def test_models(
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
with VllmRunner(
with vllm_runner(
model,
max_model_len=8192,
enforce_eager=enforce_eager,
......@@ -50,5 +56,5 @@ def test_models(
tensor_parallel_size=tensor_parallel_size) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
output = vllm_outputs[0][1]
assert "1024" in output
output = vllm_outputs[0][1]
assert "1024" in output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment