Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d9784107
Unverified
Commit
d9784107
authored
Jul 21, 2025
by
Ning Xie
Committed by
GitHub
Jul 21, 2025
Browse files
[Misc] unify variable for LLM instance (#20996)
Signed-off-by:
Andy Xie
<
andy.xning@gmail.com
>
parent
e6b90a28
Changes
53
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
75 additions
and
77 deletions
+75
-77
tests/quantization/test_quark.py
tests/quantization/test_quark.py
+2
-2
tests/quantization/test_register_quantization_config.py
tests/quantization/test_register_quantization_config.py
+1
-1
tests/samplers/test_ignore_eos.py
tests/samplers/test_ignore_eos.py
+1
-1
tests/samplers/test_logits_processor.py
tests/samplers/test_logits_processor.py
+5
-5
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+2
-2
tests/samplers/test_no_bad_words.py
tests/samplers/test_no_bad_words.py
+6
-6
tests/samplers/test_seeded_generate.py
tests/samplers/test_seeded_generate.py
+1
-1
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+1
-1
tests/v1/core/test_scheduler_e2e.py
tests/v1/core/test_scheduler_e2e.py
+6
-6
tests/v1/engine/test_llm_engine.py
tests/v1/engine/test_llm_engine.py
+7
-7
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+4
-4
tests/v1/sample/test_sampling_params_e2e.py
tests/v1/sample/test_sampling_params_e2e.py
+36
-38
tests/v1/test_oracle.py
tests/v1/test_oracle.py
+3
-3
No files found.
tests/quantization/test_quark.py
View file @
d9784107
...
@@ -107,11 +107,11 @@ def test_quark_fp8_parity(vllm_runner):
...
@@ -107,11 +107,11 @@ def test_quark_fp8_parity(vllm_runner):
}
}
with
(
vllm_runner
(
quark_model_id
,
**
llm_kwargs
)
as
with
(
vllm_runner
(
quark_model_id
,
**
llm_kwargs
)
as
quark_handle
,
vllm_runner
(
fp8_model_id
,
**
llm_kwargs
)
as
fp8_handle
):
quark_handle
,
vllm_runner
(
fp8_model_id
,
**
llm_kwargs
)
as
fp8_handle
):
quark_model
=
(
quark_handle
.
model
.
llm_engine
.
model_executor
.
quark_model
=
(
quark_handle
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
driver_worker
.
model_runner
.
model
)
quark_state_dict
=
quark_model
.
state_dict
()
quark_state_dict
=
quark_model
.
state_dict
()
fp8_model
=
(
fp8_handle
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
fp8_model
=
(
fp8_handle
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
model_runner
.
model
)
fp8_state_dict
=
fp8_model
.
state_dict
()
fp8_state_dict
=
fp8_model
.
state_dict
()
...
...
tests/quantization/test_register_quantization_config.py
View file @
d9784107
...
@@ -111,7 +111,7 @@ def test_custom_quant(vllm_runner, model, monkeypatch):
...
@@ -111,7 +111,7 @@ def test_custom_quant(vllm_runner, model, monkeypatch):
quantization
=
"custom_quant"
,
quantization
=
"custom_quant"
,
enforce_eager
=
True
)
as
llm
:
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
model
=
llm
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
qkv_proj
=
layer
.
self_attn
.
qkv_proj
...
...
tests/samplers/test_ignore_eos.py
View file @
d9784107
...
@@ -36,7 +36,7 @@ def test_ignore_eos(
...
@@ -36,7 +36,7 @@ def test_ignore_eos(
ignore_eos
=
True
)
ignore_eos
=
True
)
for
prompt
in
example_prompts
:
for
prompt
in
example_prompts
:
ignore_eos_output
=
vllm_model
.
model
.
generate
(
ignore_eos_output
=
vllm_model
.
llm
.
generate
(
prompt
,
sampling_params
=
sampling_params
)
prompt
,
sampling_params
=
sampling_params
)
output_length
=
len
(
ignore_eos_output
[
0
].
outputs
[
0
].
token_ids
)
output_length
=
len
(
ignore_eos_output
[
0
].
outputs
[
0
].
token_ids
)
assert
output_length
==
max_tokens
assert
output_length
==
max_tokens
tests/samplers/test_logits_processor.py
View file @
d9784107
...
@@ -26,7 +26,7 @@ def test_logits_processor_force_generate(
...
@@ -26,7 +26,7 @@ def test_logits_processor_force_generate(
dtype
:
str
,
dtype
:
str
,
)
->
None
:
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
tokenizer
=
vllm_model
.
model
.
get_tokenizer
()
tokenizer
=
vllm_model
.
llm
.
get_tokenizer
()
repeat_times
=
2
repeat_times
=
2
enforced_answers
=
" vLLM"
enforced_answers
=
" vLLM"
vllm_token_ids
=
tokenizer
.
encode
(
enforced_answers
,
vllm_token_ids
=
tokenizer
.
encode
(
enforced_answers
,
...
@@ -45,13 +45,13 @@ def test_logits_processor_force_generate(
...
@@ -45,13 +45,13 @@ def test_logits_processor_force_generate(
)
)
# test logits_processors when prompt_logprobs is not None
# test logits_processors when prompt_logprobs is not None
vllm_model
.
model
.
_add_request
(
vllm_model
.
llm
.
_add_request
(
example_prompts
[
0
],
example_prompts
[
0
],
params
=
params_with_logprobs
,
params
=
params_with_logprobs
,
)
)
# test prompt_logprobs is not None
# test prompt_logprobs is not None
vllm_model
.
model
.
_add_request
(
vllm_model
.
llm
.
_add_request
(
example_prompts
[
1
],
example_prompts
[
1
],
params
=
SamplingParams
(
params
=
SamplingParams
(
prompt_logprobs
=
3
,
prompt_logprobs
=
3
,
...
@@ -60,11 +60,11 @@ def test_logits_processor_force_generate(
...
@@ -60,11 +60,11 @@ def test_logits_processor_force_generate(
)
)
# test grouped requests
# test grouped requests
vllm_model
.
model
.
_add_request
(
vllm_model
.
llm
.
_add_request
(
example_prompts
[
2
],
example_prompts
[
2
],
params
=
SamplingParams
(
max_tokens
=
max_tokens
),
params
=
SamplingParams
(
max_tokens
=
max_tokens
),
)
)
outputs
=
vllm_model
.
model
.
_run_engine
(
use_tqdm
=
False
)
outputs
=
vllm_model
.
llm
.
_run_engine
(
use_tqdm
=
False
)
assert
outputs
[
0
].
outputs
[
0
].
text
==
enforced_answers
*
repeat_times
assert
outputs
[
0
].
outputs
[
0
].
text
==
enforced_answers
*
repeat_times
tests/samplers/test_logprobs.py
View file @
d9784107
...
@@ -64,7 +64,7 @@ def test_get_prompt_logprobs(
...
@@ -64,7 +64,7 @@ def test_get_prompt_logprobs(
prompt_logprobs
=
num_top_logprobs
,
prompt_logprobs
=
num_top_logprobs
,
temperature
=
0.0
,
temperature
=
0.0
,
detokenize
=
detokenize
)
detokenize
=
detokenize
)
vllm_results
=
vllm_model
.
model
.
generate
(
vllm_results
=
vllm_model
.
llm
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
example_prompts
,
sampling_params
=
vllm_sampling_params
)
# Test whether logprobs are included in the results.
# Test whether logprobs are included in the results.
...
@@ -174,7 +174,7 @@ def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
...
@@ -174,7 +174,7 @@ def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
logprobs
=
None
,
logprobs
=
None
,
temperature
=
0.0
,
temperature
=
0.0
,
detokenize
=
detokenize
)
detokenize
=
detokenize
)
results_logprobs_none
=
vllm_model
.
model
.
generate
(
results_logprobs_none
=
vllm_model
.
llm
.
generate
(
example_prompts
,
sampling_params
=
sampling_params_logprobs_none
)
example_prompts
,
sampling_params
=
sampling_params_logprobs_none
)
for
i
in
range
(
len
(
results_logprobs_none
)):
for
i
in
range
(
len
(
results_logprobs_none
)):
...
...
tests/samplers/test_no_bad_words.py
View file @
d9784107
...
@@ -20,7 +20,7 @@ def v1(run_with_both_engines):
...
@@ -20,7 +20,7 @@ def v1(run_with_both_engines):
def
_generate
(
def
_generate
(
model
:
LLM
,
llm
:
LLM
,
prompt
:
str
,
prompt
:
str
,
num_prompt_tokens
:
int
,
num_prompt_tokens
:
int
,
temperature
:
float
=
0
,
temperature
:
float
=
0
,
...
@@ -32,7 +32,7 @@ def _generate(
...
@@ -32,7 +32,7 @@ def _generate(
)
)
# [([output_token_ids, ], [output_text, ]), ]
# [([output_token_ids, ], [output_text, ]), ]
output
=
model
.
generate
([
prompt
],
sampling_params
=
sampling_params
)
output
=
llm
.
generate
([
prompt
],
sampling_params
=
sampling_params
)
output_token_ids
=
output
[
0
][
0
][
0
][
num_prompt_tokens
:]
output_token_ids
=
output
[
0
][
0
][
0
][
num_prompt_tokens
:]
# [0] first (and only) request output
# [0] first (and only) request output
...
@@ -66,10 +66,10 @@ class TestOneTokenBadWord:
...
@@ -66,10 +66,10 @@ class TestOneTokenBadWord:
assert
self
.
target_token_id
not
in
output_token_ids
assert
self
.
target_token_id
not
in
output_token_ids
def
_generate
(
self
,
def
_generate
(
self
,
model
:
LLM
,
llm
:
LLM
,
bad_words
:
Optional
[
list
[
str
]]
=
None
)
->
list
[
int
]:
bad_words
:
Optional
[
list
[
str
]]
=
None
)
->
list
[
int
]:
return
_generate
(
return
_generate
(
model
=
model
,
llm
=
llm
,
prompt
=
self
.
PROMPT
,
prompt
=
self
.
PROMPT
,
num_prompt_tokens
=
self
.
num_prompt_tokens
,
num_prompt_tokens
=
self
.
num_prompt_tokens
,
bad_words
=
bad_words
,
bad_words
=
bad_words
,
...
@@ -156,10 +156,10 @@ class TestTwoTokenBadWord:
...
@@ -156,10 +156,10 @@ class TestTwoTokenBadWord:
or
(
self
.
neighbour_token_id2
in
output_token_ids
))
or
(
self
.
neighbour_token_id2
in
output_token_ids
))
def
_generate
(
self
,
def
_generate
(
self
,
model
:
LLM
,
llm
:
LLM
,
bad_words
:
Optional
[
list
[
str
]]
=
None
)
->
list
[
int
]:
bad_words
:
Optional
[
list
[
str
]]
=
None
)
->
list
[
int
]:
return
_generate
(
return
_generate
(
model
=
model
,
llm
=
llm
,
prompt
=
self
.
PROMPT
,
prompt
=
self
.
PROMPT
,
num_prompt_tokens
=
self
.
num_prompt_tokens
,
num_prompt_tokens
=
self
.
num_prompt_tokens
,
bad_words
=
bad_words
,
bad_words
=
bad_words
,
...
...
tests/samplers/test_seeded_generate.py
View file @
d9784107
...
@@ -49,7 +49,7 @@ def test_random_sample_with_seed(
...
@@ -49,7 +49,7 @@ def test_random_sample_with_seed(
sampling_params_seed_2
=
copy
.
deepcopy
(
sampling_params
)
sampling_params_seed_2
=
copy
.
deepcopy
(
sampling_params
)
sampling_params_seed_2
.
seed
=
200
sampling_params_seed_2
.
seed
=
200
llm
=
vllm_model
.
model
llm
=
vllm_model
.
llm
for
prompt
in
example_prompts
:
for
prompt
in
example_prompts
:
for
params
in
(
for
params
in
(
...
...
tests/tokenization/test_detokenize.py
View file @
d9784107
...
@@ -393,7 +393,7 @@ def test_decode_prompt_logprobs_chunked_prefill(
...
@@ -393,7 +393,7 @@ def test_decode_prompt_logprobs_chunked_prefill(
logprobs
=
5
,
logprobs
=
5
,
prompt_logprobs
=
5
,
prompt_logprobs
=
5
,
temperature
=
0.0
)
temperature
=
0.0
)
vllm_results
=
vllm_model
.
model
.
generate
(
vllm_results
=
vllm_model
.
llm
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
example_prompts
,
sampling_params
=
vllm_sampling_params
)
for
idx
,
result
in
enumerate
(
vllm_results
):
for
idx
,
result
in
enumerate
(
vllm_results
):
...
...
tests/v1/core/test_scheduler_e2e.py
View file @
d9784107
...
@@ -14,7 +14,7 @@ PROMPT = "Hello my name is Robert and I"
...
@@ -14,7 +14,7 @@ PROMPT = "Hello my name is Robert and I"
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
model
()
->
LLM
:
def
llm
()
->
LLM
:
return
LLM
(
MODEL
,
return
LLM
(
MODEL
,
enforce_eager
=
True
,
enforce_eager
=
True
,
enable_prefix_caching
=
True
,
enable_prefix_caching
=
True
,
...
@@ -24,16 +24,16 @@ def model() -> LLM:
...
@@ -24,16 +24,16 @@ def model() -> LLM:
block_size
=
16
)
block_size
=
16
)
def
test_concurrent_partial_prefill
(
model
):
def
test_concurrent_partial_prefill
(
llm
):
outputs
=
model
.
generate
([
PROMPT
]
*
3
)
outputs
=
llm
.
generate
([
PROMPT
]
*
3
)
assert
len
(
outputs
)
==
3
assert
len
(
outputs
)
==
3
for
output
in
outputs
:
for
output
in
outputs
:
assert
len
(
output
.
outputs
)
==
1
assert
len
(
output
.
outputs
)
==
1
def
test_prefix_cache_stats_is_recorded
(
model
):
def
test_prefix_cache_stats_is_recorded
(
llm
):
# 17 tokens will make sure first 16 tokens are cached in a block
# 17 tokens will make sure first 16 tokens are cached in a block
input_tokens
=
{
"prompt_token_ids"
:
[
101
]
*
17
}
input_tokens
=
{
"prompt_token_ids"
:
[
101
]
*
17
}
_
=
model
.
generate
([
input_tokens
])
_
=
llm
.
generate
([
input_tokens
])
outputs
=
model
.
generate
([
input_tokens
])
outputs
=
llm
.
generate
([
input_tokens
])
assert
outputs
[
0
].
num_cached_tokens
==
16
assert
outputs
[
0
].
num_cached_tokens
==
16
tests/v1/engine/test_llm_engine.py
View file @
d9784107
...
@@ -112,9 +112,9 @@ def test_compatibility_with_skip_tokenizer_init(
...
@@ -112,9 +112,9 @@ def test_compatibility_with_skip_tokenizer_init(
example_prompts
,
example_prompts
,
structured_outputs
=
True
,
structured_outputs
=
True
,
)
)
model
:
LLM
=
vllm_model_skip_tokenizer_init
.
model
llm
:
LLM
=
vllm_model_skip_tokenizer_init
.
llm
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
_
=
model
.
generate
(
example_prompts
,
sampling_params_list
)
_
=
llm
.
generate
(
example_prompts
,
sampling_params_list
)
def
test_parallel_sampling
(
vllm_model
,
example_prompts
)
->
None
:
def
test_parallel_sampling
(
vllm_model
,
example_prompts
)
->
None
:
...
@@ -125,8 +125,8 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
...
@@ -125,8 +125,8 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
example_prompt: test fixture providing prompts for testing.
example_prompt: test fixture providing prompts for testing.
"""
"""
sampling_params_list
,
n_list
=
_get_test_sampling_params
(
example_prompts
)
sampling_params_list
,
n_list
=
_get_test_sampling_params
(
example_prompts
)
model
:
LLM
=
vllm_model
.
model
llm
:
LLM
=
vllm_model
.
llm
outputs
=
model
.
generate
(
example_prompts
,
sampling_params_list
)
outputs
=
llm
.
generate
(
example_prompts
,
sampling_params_list
)
# Validate each request response
# Validate each request response
for
out
,
n
in
zip
(
outputs
,
n_list
):
for
out
,
n
in
zip
(
outputs
,
n_list
):
...
@@ -166,10 +166,10 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
...
@@ -166,10 +166,10 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
speculative_config
=
speculative_config
,
speculative_config
=
speculative_config
,
disable_log_stats
=
False
,
disable_log_stats
=
False
,
)
as
vllm_model
:
)
as
vllm_model
:
model
:
LLM
=
vllm_model
.
model
llm
:
LLM
=
vllm_model
.
llm
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
)
max_tokens
=
max_tokens
)
outputs
=
model
.
generate
(
example_prompts
,
sampling_params
)
outputs
=
llm
.
generate
(
example_prompts
,
sampling_params
)
n_prompts
=
len
(
example_prompts
)
n_prompts
=
len
(
example_prompts
)
assert
len
(
outputs
)
==
n_prompts
assert
len
(
outputs
)
==
n_prompts
...
@@ -180,7 +180,7 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
...
@@ -180,7 +180,7 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
total_tokens
+=
len
(
out
.
outputs
[
0
].
token_ids
)
total_tokens
+=
len
(
out
.
outputs
[
0
].
token_ids
)
assert
total_tokens
==
max_tokens
*
n_prompts
assert
total_tokens
==
max_tokens
*
n_prompts
metrics
=
model
.
get_metrics
()
metrics
=
llm
.
get_metrics
()
def
find_metric
(
name
)
->
list
[
Metric
]:
def
find_metric
(
name
)
->
list
[
Metric
]:
found
=
[]
found
=
[]
...
...
tests/v1/sample/test_logprobs.py
View file @
d9784107
...
@@ -112,7 +112,7 @@ def _run_and_validate(
...
@@ -112,7 +112,7 @@ def _run_and_validate(
max_tokens
:
int
,
max_tokens
:
int
,
do_apc
:
bool
,
do_apc
:
bool
,
)
->
None
:
)
->
None
:
vllm_results
=
vllm_model
.
model
.
generate
(
vllm_results
=
vllm_model
.
llm
.
generate
(
test_prompts
,
sampling_params
=
vllm_sampling_params
)
test_prompts
,
sampling_params
=
vllm_sampling_params
)
for
vllm_result
,
hf_logprob
,
hf_output
,
logprob_prompt_logprob
in
zip
(
for
vllm_result
,
hf_logprob
,
hf_output
,
logprob_prompt_logprob
in
zip
(
...
@@ -288,7 +288,7 @@ def test_get_logprobs_and_prompt_logprobs(
...
@@ -288,7 +288,7 @@ def test_get_logprobs_and_prompt_logprobs(
"""
"""
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
do_apc
=
vllm_model
.
model
.
llm_engine
.
cache_config
.
enable_prefix_caching
do_apc
=
vllm_model
.
llm
.
llm_engine
.
cache_config
.
enable_prefix_caching
if
do_apc
and
(
temperature
<
2.0
if
do_apc
and
(
temperature
<
2.0
or
batch_logprobs_composition
!=
SAMPLE_PROMPT
):
or
batch_logprobs_composition
!=
SAMPLE_PROMPT
):
# Skip some test-cases to save time.
# Skip some test-cases to save time.
...
@@ -378,7 +378,7 @@ def test_none_logprobs(vllm_model, example_prompts,
...
@@ -378,7 +378,7 @@ def test_none_logprobs(vllm_model, example_prompts,
prompt_logprobs
=
None
,
prompt_logprobs
=
None
,
temperature
=
0.0
,
temperature
=
0.0
,
)
)
results_logprobs_none
=
vllm_model
.
model
.
generate
(
results_logprobs_none
=
vllm_model
.
llm
.
generate
(
example_prompts
,
example_prompts
,
sampling_params
=
sampling_params_logprobs_none
,
sampling_params
=
sampling_params_logprobs_none
,
)
)
...
@@ -408,7 +408,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
...
@@ -408,7 +408,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
logprobs
=
0
,
logprobs
=
0
,
prompt_logprobs
=
0
,
prompt_logprobs
=
0
,
temperature
=
0.0
)
temperature
=
0.0
)
results_logprobs_zero
=
vllm_model
.
model
.
generate
(
results_logprobs_zero
=
vllm_model
.
llm
.
generate
(
example_prompts
,
sampling_params
=
sampling_params_logprobs_zero
)
example_prompts
,
sampling_params
=
sampling_params_logprobs_zero
)
for
i
in
range
(
len
(
results_logprobs_zero
)):
for
i
in
range
(
len
(
results_logprobs_zero
)):
...
...
tests/v1/sample/test_sampling_params_e2e.py
View file @
d9784107
...
@@ -14,30 +14,30 @@ PROMPT = "Hello my name is Robert and I"
...
@@ -14,30 +14,30 @@ PROMPT = "Hello my name is Robert and I"
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
model
()
->
LLM
:
def
llm
()
->
LLM
:
# Disable prefix caching so that we can test prompt logprobs.
# Disable prefix caching so that we can test prompt logprobs.
# TODO remove this after https://github.com/vllm-project/vllm/pull/13949
# TODO remove this after https://github.com/vllm-project/vllm/pull/13949
# is merged
# is merged
return
LLM
(
MODEL
,
enforce_eager
=
True
,
enable_prefix_caching
=
False
)
return
LLM
(
MODEL
,
enforce_eager
=
True
,
enable_prefix_caching
=
False
)
def
test_n_gt_1
(
model
):
def
test_n_gt_1
(
llm
):
"""ParallelSampling is supported."""
"""ParallelSampling is supported."""
params
=
SamplingParams
(
n
=
3
)
params
=
SamplingParams
(
n
=
3
)
outputs
=
model
.
generate
(
PROMPT
,
params
)
outputs
=
llm
.
generate
(
PROMPT
,
params
)
assert
len
(
outputs
[
0
].
outputs
)
==
3
assert
len
(
outputs
[
0
].
outputs
)
==
3
def
test_best_of
(
model
):
def
test_best_of
(
llm
):
"""Raise a ValueError since best_of is deprecated."""
"""Raise a ValueError since best_of is deprecated."""
params
=
SamplingParams
(
n
=
2
,
best_of
=
3
)
params
=
SamplingParams
(
n
=
2
,
best_of
=
3
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
_
=
model
.
generate
(
PROMPT
,
params
)
_
=
llm
.
generate
(
PROMPT
,
params
)
def
test_penalties
(
model
):
def
test_penalties
(
llm
):
"""Check that we do not get errors if applied."""
"""Check that we do not get errors if applied."""
params
=
SamplingParams
(
params
=
SamplingParams
(
...
@@ -49,18 +49,18 @@ def test_penalties(model):
...
@@ -49,18 +49,18 @@ def test_penalties(model):
top_p
=
0.5
,
top_p
=
0.5
,
top_k
=
3
,
top_k
=
3
,
)
)
_
=
model
.
generate
(
PROMPT
,
params
)
_
=
llm
.
generate
(
PROMPT
,
params
)
def
test_stop
(
model
):
def
test_stop
(
llm
):
"""Check that we respect the stop words."""
"""Check that we respect the stop words."""
output
=
model
.
generate
(
PROMPT
,
SamplingParams
(
temperature
=
0
))
output
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
temperature
=
0
))
split_text
=
output
[
0
].
outputs
[
0
].
text
.
split
()
split_text
=
output
[
0
].
outputs
[
0
].
text
.
split
()
STOP_IDX
=
5
STOP_IDX
=
5
params
=
SamplingParams
(
temperature
=
0
,
stop
=
split_text
[
STOP_IDX
])
params
=
SamplingParams
(
temperature
=
0
,
stop
=
split_text
[
STOP_IDX
])
output
=
model
.
generate
(
PROMPT
,
params
)
output
=
llm
.
generate
(
PROMPT
,
params
)
new_split_text
=
output
[
0
].
outputs
[
0
].
text
.
split
()
new_split_text
=
output
[
0
].
outputs
[
0
].
text
.
split
()
# Output should not contain the stop word.
# Output should not contain the stop word.
...
@@ -69,40 +69,40 @@ def test_stop(model):
...
@@ -69,40 +69,40 @@ def test_stop(model):
params
=
SamplingParams
(
temperature
=
0
,
params
=
SamplingParams
(
temperature
=
0
,
stop
=
split_text
[
STOP_IDX
],
stop
=
split_text
[
STOP_IDX
],
include_stop_str_in_output
=
True
)
include_stop_str_in_output
=
True
)
output
=
model
.
generate
(
PROMPT
,
params
)
output
=
llm
.
generate
(
PROMPT
,
params
)
new_split_text
=
output
[
0
].
outputs
[
0
].
text
.
split
()
new_split_text
=
output
[
0
].
outputs
[
0
].
text
.
split
()
# Output should contain the stop word.
# Output should contain the stop word.
assert
len
(
new_split_text
)
==
STOP_IDX
+
1
assert
len
(
new_split_text
)
==
STOP_IDX
+
1
def
test_stop_token_ids
(
model
):
def
test_stop_token_ids
(
llm
):
"""Check that we respect the stop token ids."""
"""Check that we respect the stop token ids."""
output
=
model
.
generate
(
PROMPT
,
SamplingParams
(
temperature
=
0
))
output
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
temperature
=
0
))
stop_token_id_0
=
output
[
0
].
outputs
[
0
].
token_ids
[
5
]
stop_token_id_0
=
output
[
0
].
outputs
[
0
].
token_ids
[
5
]
stop_token_id_1
=
output
[
0
].
outputs
[
0
].
token_ids
[
6
]
stop_token_id_1
=
output
[
0
].
outputs
[
0
].
token_ids
[
6
]
stop_token_ids
=
[
stop_token_id_1
,
stop_token_id_0
]
stop_token_ids
=
[
stop_token_id_1
,
stop_token_id_0
]
params
=
SamplingParams
(
temperature
=
0
,
stop_token_ids
=
stop_token_ids
)
params
=
SamplingParams
(
temperature
=
0
,
stop_token_ids
=
stop_token_ids
)
output
=
model
.
generate
(
PROMPT
,
params
)
output
=
llm
.
generate
(
PROMPT
,
params
)
assert
output
[
0
].
outputs
[
0
].
token_ids
[
-
1
]
==
stop_token_id_0
assert
output
[
0
].
outputs
[
0
].
token_ids
[
-
1
]
==
stop_token_id_0
stop_token_ids
=
[
stop_token_id_0
,
stop_token_id_1
]
stop_token_ids
=
[
stop_token_id_0
,
stop_token_id_1
]
params
=
SamplingParams
(
temperature
=
0
,
stop_token_ids
=
stop_token_ids
)
params
=
SamplingParams
(
temperature
=
0
,
stop_token_ids
=
stop_token_ids
)
output
=
model
.
generate
(
PROMPT
,
params
)
output
=
llm
.
generate
(
PROMPT
,
params
)
assert
output
[
0
].
outputs
[
0
].
token_ids
[
-
1
]
==
stop_token_id_0
assert
output
[
0
].
outputs
[
0
].
token_ids
[
-
1
]
==
stop_token_id_0
def
test_detokenize_false
(
model
):
def
test_detokenize_false
(
llm
):
"""Check that detokenize=False option works."""
"""Check that detokenize=False option works."""
output
=
model
.
generate
(
PROMPT
,
SamplingParams
(
detokenize
=
False
))
output
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
detokenize
=
False
))
assert
len
(
output
[
0
].
outputs
[
0
].
token_ids
)
>
0
assert
len
(
output
[
0
].
outputs
[
0
].
token_ids
)
>
0
assert
len
(
output
[
0
].
outputs
[
0
].
text
)
==
0
assert
len
(
output
[
0
].
outputs
[
0
].
text
)
==
0
output
=
model
.
generate
(
output
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
detokenize
=
False
,
logprobs
=
3
,
PROMPT
,
SamplingParams
(
detokenize
=
False
,
logprobs
=
3
,
prompt_logprobs
=
3
))
prompt_logprobs
=
3
))
assert
len
(
output
[
0
].
outputs
[
0
].
token_ids
)
>
0
assert
len
(
output
[
0
].
outputs
[
0
].
token_ids
)
>
0
...
@@ -118,28 +118,28 @@ def test_detokenize_false(model):
...
@@ -118,28 +118,28 @@ def test_detokenize_false(model):
assert
all
(
lp
.
decoded_token
is
None
for
lp
in
logprobs
.
values
())
assert
all
(
lp
.
decoded_token
is
None
for
lp
in
logprobs
.
values
())
def
test_bad_words
(
model
):
def
test_bad_words
(
llm
):
"""Check that we respect bad words."""
"""Check that we respect bad words."""
output
=
model
.
generate
(
PROMPT
,
SamplingParams
(
temperature
=
0
))
output
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
temperature
=
0
))
split_text
=
output
[
0
].
outputs
[
0
].
text
.
split
()
split_text
=
output
[
0
].
outputs
[
0
].
text
.
split
()
bad_words_1
=
" "
.
join
(
split_text
[:
2
])
bad_words_1
=
" "
.
join
(
split_text
[:
2
])
params
=
SamplingParams
(
temperature
=
0
,
bad_words
=
[
bad_words_1
])
params
=
SamplingParams
(
temperature
=
0
,
bad_words
=
[
bad_words_1
])
output
=
model
.
generate
(
PROMPT
,
params
)
output
=
llm
.
generate
(
PROMPT
,
params
)
new_text
=
output
[
0
].
outputs
[
0
].
text
new_text
=
output
[
0
].
outputs
[
0
].
text
assert
bad_words_1
not
in
new_text
assert
bad_words_1
not
in
new_text
bad_words_2
=
new_text
.
split
()[
-
1
]
bad_words_2
=
new_text
.
split
()[
-
1
]
params
=
SamplingParams
(
temperature
=
0
,
params
=
SamplingParams
(
temperature
=
0
,
bad_words
=
[
bad_words_1
,
bad_words_2
])
bad_words
=
[
bad_words_1
,
bad_words_2
])
output
=
model
.
generate
(
PROMPT
,
params
)
output
=
llm
.
generate
(
PROMPT
,
params
)
new_text
=
output
[
0
].
outputs
[
0
].
text
new_text
=
output
[
0
].
outputs
[
0
].
text
assert
bad_words_1
not
in
new_text
assert
bad_words_1
not
in
new_text
assert
bad_words_2
not
in
new_text
assert
bad_words_2
not
in
new_text
def
test_logits_processor
(
model
):
def
test_logits_processor
(
llm
):
"""Check that we reject logits processor."""
"""Check that we reject logits processor."""
# This sample logits processor gives infinite score to the i-th token,
# This sample logits processor gives infinite score to the i-th token,
...
@@ -150,47 +150,45 @@ def test_logits_processor(model):
...
@@ -150,47 +150,45 @@ def test_logits_processor(model):
return
logits
return
logits
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
_
=
model
.
generate
(
PROMPT
,
_
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
logits_processors
=
[
pick_ith
]))
SamplingParams
(
logits_processors
=
[
pick_ith
]))
def
test_allowed_token_ids
(
model
):
def
test_allowed_token_ids
(
llm
):
"""Check that we can use allowed_token_ids."""
"""Check that we can use allowed_token_ids."""
TOKEN_ID
=
10
TOKEN_ID
=
10
allowed_token_ids
=
[
TOKEN_ID
]
allowed_token_ids
=
[
TOKEN_ID
]
output
=
model
.
generate
(
output
=
llm
.
generate
(
PROMPT
,
PROMPT
,
SamplingParams
(
allowed_token_ids
=
allowed_token_ids
))
SamplingParams
(
allowed_token_ids
=
allowed_token_ids
))
assert
output
[
0
].
outputs
[
0
].
token_ids
[
-
1
]
==
TOKEN_ID
assert
output
[
0
].
outputs
[
0
].
token_ids
[
-
1
]
==
TOKEN_ID
# Reject empty allowed_token_ids.
# Reject empty allowed_token_ids.
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
_
=
model
.
generate
(
PROMPT
,
SamplingParams
(
allowed_token_ids
=
[]))
_
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
allowed_token_ids
=
[]))
# Reject negative token id.
# Reject negative token id.
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
_
=
model
.
generate
(
PROMPT
,
SamplingParams
(
allowed_token_ids
=
[
-
1
]))
_
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
allowed_token_ids
=
[
-
1
]))
# Reject out of vocabulary.
# Reject out of vocabulary.
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
_
=
model
.
generate
(
PROMPT
,
_
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
allowed_token_ids
=
[
10000000
]))
SamplingParams
(
allowed_token_ids
=
[
10000000
]))
def
test_priority
(
model
):
def
test_priority
(
llm
):
"""Check that we reject requests with priority."""
"""Check that we reject requests with priority."""
# Reject all allowed token ids
# Reject all allowed token ids
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
_
=
model
.
generate
(
PROMPT
,
priority
=
[
1
])
_
=
llm
.
generate
(
PROMPT
,
priority
=
[
1
])
def
test_seed
(
model
):
def
test_seed
(
llm
):
"""Check that seed impacts randomness."""
"""Check that seed impacts randomness."""
out_1
=
model
.
generate
(
PROMPT
,
SamplingParams
(
seed
=
42
))
out_1
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
seed
=
42
))
out_2
=
model
.
generate
(
PROMPT
,
SamplingParams
(
seed
=
42
))
out_2
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
seed
=
42
))
out_3
=
model
.
generate
(
PROMPT
,
SamplingParams
(
seed
=
43
))
out_3
=
llm
.
generate
(
PROMPT
,
SamplingParams
(
seed
=
43
))
assert
out_1
[
0
].
outputs
[
0
].
text
==
out_2
[
0
].
outputs
[
0
].
text
assert
out_1
[
0
].
outputs
[
0
].
text
==
out_2
[
0
].
outputs
[
0
].
text
assert
out_1
[
0
].
outputs
[
0
].
text
!=
out_3
[
0
].
outputs
[
0
].
text
assert
out_1
[
0
].
outputs
[
0
].
text
!=
out_3
[
0
].
outputs
[
0
].
text
tests/v1/test_oracle.py
View file @
d9784107
...
@@ -106,9 +106,9 @@ def test_v1_llm_by_default(monkeypatch):
...
@@ -106,9 +106,9 @@ def test_v1_llm_by_default(monkeypatch):
m
.
delenv
(
"VLLM_USE_V1"
)
m
.
delenv
(
"VLLM_USE_V1"
)
# Should default to V1 for supported config.
# Should default to V1 for supported config.
model
=
LLM
(
MODEL
,
enforce_eager
=
True
,
enable_lora
=
True
)
llm
=
LLM
(
MODEL
,
enforce_eager
=
True
,
enable_lora
=
True
)
print
(
model
.
generate
(
"Hello my name is"
))
print
(
llm
.
generate
(
"Hello my name is"
))
assert
hasattr
(
model
.
llm_engine
,
"engine_core"
)
assert
hasattr
(
llm
.
llm_engine
,
"engine_core"
)
m
.
delenv
(
"VLLM_USE_V1"
)
m
.
delenv
(
"VLLM_USE_V1"
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment