Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a985548
Commit
7a985548
authored
May 22, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.0' into v0.9.0-ori
parents
45d3785c
dc1440cf
Changes
486
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1300 additions
and
203 deletions
+1300
-203
tests/compile/test_silu_mul_quant_fusion.py
tests/compile/test_silu_mul_quant_fusion.py
+73
-0
tests/conftest.py
tests/conftest.py
+46
-43
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+73
-1
tests/core/utils.py
tests/core/utils.py
+9
-2
tests/distributed/conftest.py
tests/distributed/conftest.py
+145
-0
tests/distributed/test_events.py
tests/distributed/test_events.py
+193
-0
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+8
-4
tests/distributed/test_sequence_parallel.py
tests/distributed/test_sequence_parallel.py
+37
-4
tests/distributed/test_torchrun_example.py
tests/distributed/test_torchrun_example.py
+2
-1
tests/engine/test_arg_utils.py
tests/engine/test_arg_utils.py
+95
-29
tests/engine/test_options.py
tests/engine/test_options.py
+60
-0
tests/entrypoints/llm/test_chat.py
tests/entrypoints/llm/test_chat.py
+93
-16
tests/entrypoints/llm/test_collective_rpc.py
tests/entrypoints/llm/test_collective_rpc.py
+2
-1
tests/entrypoints/llm/test_guided_generate.py
tests/entrypoints/llm/test_guided_generate.py
+125
-81
tests/entrypoints/openai/test_audio.py
tests/entrypoints/openai/test_audio.py
+4
-4
tests/entrypoints/openai/test_chat_template.py
tests/entrypoints/openai/test_chat_template.py
+19
-3
tests/entrypoints/openai/test_chat_with_tool_reasoning.py
tests/entrypoints/openai/test_chat_with_tool_reasoning.py
+3
-3
tests/entrypoints/openai/test_classification.py
tests/entrypoints/openai/test_classification.py
+156
-0
tests/entrypoints/openai/test_cli_args.py
tests/entrypoints/openai/test_cli_args.py
+3
-11
tests/entrypoints/openai/test_completion_with_function_calling.py
...trypoints/openai/test_completion_with_function_calling.py
+154
-0
No files found.
Too many changes to show.
To preserve performance only
486 of 486+
files are displayed.
Plain diff
Email patch
tests/compile/test_silu_mul_quant_fusion.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
import
vllm.envs
as
envs
from
vllm._custom_ops
import
scaled_fp8_quant
from
vllm.compilation.activation_quant_fusion
import
ActivationQuantFusionPass
from
vllm.compilation.fx_utils
import
find_auto_fn
,
find_auto_fn_maybe
from
vllm.config
import
CompilationConfig
,
PassConfig
,
VllmConfig
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
.backend
import
TestBackend
class TestModel(torch.nn.Module):
    """Minimal module chaining SiluAndMul with scaled FP8 quantization."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.silu_and_mul = SiluAndMul()
        # Single random float32 value handed to scaled_fp8_quant as the scale.
        self.scale = torch.rand(1, dtype=torch.float32)

    def forward(self, x):
        """Return scaled_fp8_quant applied to the SiluAndMul output of ``x``."""
        activated = self.silu_and_mul(x)
        return scaled_fp8_quant(activated, self.scale)
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
    """Check that ActivationQuantFusionPass fuses SiluAndMul with FP8 quant.

    The model is run once eagerly and once through ``torch.compile`` with a
    TestBackend wrapping the fusion pass. Both outputs must agree, and the
    graph recorded after the pass must contain the fused op in place of the
    standalone FP8 quant op.
    """
    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

    # Reshape pass is needed for the fusion pass to work
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_fusion=True, enable_noop=True))
    backend = TestBackend(ActivationQuantFusionPass(vllm_config))
    model = TestModel()

    # First dimension dynamic
    x = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(x, 0)

    eager_out = model(x)
    compiled_model = torch.compile(model, backend=backend)
    compiled_out = compiled_model(x)

    # Check that it gives the same answer
    torch.testing.assert_close(
        eager_out[0].to(dtype=torch.float16),
        compiled_out[0].to(dtype=torch.float16),
        atol=1e-3,
        rtol=1e-3,
    )

    # Check substitution worked
    pre_nodes = backend.graph_pre_pass.nodes
    post_nodes = backend.graph_post_pass.nodes

    fused_op = torch.ops._C.silu_and_mul_quant.default
    unfused_quant_op = torch.ops._C.static_scaled_fp8_quant.default

    # In pre-nodes, fp8 quant should be present and fused kernels should not.
    # NOTE(review): find_auto_fn is called only for its lookup; presumably it
    # fails when the op is absent -- confirm in vllm.compilation.fx_utils.
    assert find_auto_fn_maybe(pre_nodes, fused_op) is None
    find_auto_fn(pre_nodes, unfused_quant_op)

    # In post-nodes, fused kernels should be present and fp8 quant should not
    find_auto_fn(post_nodes, fused_op)
    assert find_auto_fn_maybe(post_nodes, unfused_quant_op) is None
tests/conftest.py
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
json
import
os
import
os
import
tempfile
import
tempfile
from
collections
import
UserList
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Any
,
Callable
,
Optional
,
TypedDict
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypedDict
,
TypeVar
,
Union
...
@@ -58,16 +56,12 @@ def _read_prompts(filename: str) -> list[str]:
...
@@ -58,16 +56,12 @@ def _read_prompts(filename: str) -> list[str]:
return
prompts
return
prompts
class
_
ImageAssetPrompts
(
TypedDict
):
class
ImageAssetPrompts
(
TypedDict
):
stop_sign
:
str
stop_sign
:
str
cherry_blossom
:
str
cherry_blossom
:
str
class
_ImageAssetsBase
(
UserList
[
ImageAsset
]):
class
ImageTestAssets
(
list
[
ImageAsset
]):
pass
class
_ImageAssets
(
_ImageAssetsBase
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
([
super
().
__init__
([
...
@@ -75,7 +69,7 @@ class _ImageAssets(_ImageAssetsBase):
...
@@ -75,7 +69,7 @@ class _ImageAssets(_ImageAssetsBase):
ImageAsset
(
"cherry_blossom"
),
ImageAsset
(
"cherry_blossom"
),
])
])
def
prompts
(
self
,
prompts
:
_
ImageAssetPrompts
)
->
list
[
str
]:
def
prompts
(
self
,
prompts
:
ImageAssetPrompts
)
->
list
[
str
]:
"""
"""
Convenience method to define the prompt for each test image.
Convenience method to define the prompt for each test image.
...
@@ -85,30 +79,27 @@ class _ImageAssets(_ImageAssetsBase):
...
@@ -85,30 +79,27 @@ class _ImageAssets(_ImageAssetsBase):
return
[
prompts
[
"stop_sign"
],
prompts
[
"cherry_blossom"
]]
return
[
prompts
[
"stop_sign"
],
prompts
[
"cherry_blossom"
]]
class
_
VideoAssetPrompts
(
TypedDict
):
class
VideoAssetPrompts
(
TypedDict
):
sample_demo_1
:
str
baby_reading
:
str
class
_VideoAssetsBase
(
UserList
[
VideoAsset
]):
class
VideoTestAssets
(
list
[
VideoAsset
]):
pass
class
_VideoAssets
(
_VideoAssetsBase
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
([
super
().
__init__
([
VideoAsset
(
"
sample_demo_1.mp4
"
),
VideoAsset
(
"
baby_reading
"
),
])
])
def
prompts
(
self
,
prompts
:
_
VideoAssetPrompts
)
->
list
[
str
]:
def
prompts
(
self
,
prompts
:
VideoAssetPrompts
)
->
list
[
str
]:
return
[
prompts
[
"
sample_demo_1
"
]]
return
[
prompts
[
"
baby_reading
"
]]
class
_AudioAssetsBase
(
UserList
[
AudioAsset
]):
class
AudioAssetPrompts
(
TypedDict
):
pass
mary_had_lamb
:
str
winning_call
:
str
class
_
AudioAssets
(
_
AudioAsset
sBase
):
class
Audio
Test
Assets
(
list
[
AudioAsset
]
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
([
super
().
__init__
([
...
@@ -116,13 +107,16 @@ class _AudioAssets(_AudioAssetsBase):
...
@@ -116,13 +107,16 @@ class _AudioAssets(_AudioAssetsBase):
AudioAsset
(
"winning_call"
),
AudioAsset
(
"winning_call"
),
])
])
def
prompts
(
self
,
prompts
:
AudioAssetPrompts
)
->
list
[
str
]:
return
[
prompts
[
"mary_had_lamb"
],
prompts
[
"winning_call"
]]
IMAGE_ASSETS
=
_
ImageAssets
()
IMAGE_ASSETS
=
Image
Test
Assets
()
"""Singleton instance of
:
class
:`_
ImageAssets`."""
"""Singleton instance of
{
class
}`
Image
Test
Assets`."""
VIDEO_ASSETS
=
_
VideoAssets
()
VIDEO_ASSETS
=
Video
Test
Assets
()
"""Singleton instance of
:
class
:`_
VideoAssets`."""
"""Singleton instance of
{
class
}`
Video
Test
Assets`."""
AUDIO_ASSETS
=
_
AudioAssets
()
AUDIO_ASSETS
=
Audio
Test
Assets
()
"""Singleton instance of
:
class
:`_
AudioAssets`."""
"""Singleton instance of
{
class
}`
Audio
Test
Assets`."""
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
...
@@ -270,17 +264,17 @@ def example_long_prompts() -> list[str]:
...
@@ -270,17 +264,17 @@ def example_long_prompts() -> list[str]:
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
image_assets
()
->
_
ImageAssets
:
def
image_assets
()
->
Image
Test
Assets
:
return
IMAGE_ASSETS
return
IMAGE_ASSETS
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
video_assets
()
->
_
VideoAssets
:
def
video_assets
()
->
Video
Test
Assets
:
return
VIDEO_ASSETS
return
VIDEO_ASSETS
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
audio_assets
()
->
_
AudioAssets
:
def
audio_assets
()
->
Audio
Test
Assets
:
return
AUDIO_ASSETS
return
AUDIO_ASSETS
...
@@ -293,7 +287,8 @@ class HfRunner:
...
@@ -293,7 +287,8 @@ class HfRunner:
def
get_default_device
(
self
):
def
get_default_device
(
self
):
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
return
(
"cpu"
if
current_platform
.
is_cpu
()
else
"cuda"
)
return
(
"cpu"
if
current_platform
.
is_cpu
()
else
current_platform
.
device_type
)
def
wrap_device
(
self
,
x
:
_T
,
device
:
Optional
[
str
]
=
None
)
->
_T
:
def
wrap_device
(
self
,
x
:
_T
,
device
:
Optional
[
str
]
=
None
)
->
_T
:
if
x
is
None
or
isinstance
(
x
,
(
bool
,
)):
if
x
is
None
or
isinstance
(
x
,
(
bool
,
)):
...
@@ -360,10 +355,16 @@ class HfRunner:
...
@@ -360,10 +355,16 @@ class HfRunner:
**
model_kwargs
,
**
model_kwargs
,
)
)
# in case some unquantized custom models are not in same dtype
if
(
getattr
(
model
,
"quantization_method"
,
None
)
is
None
and
any
(
p
.
dtype
!=
self
.
dtype
for
p
in
model
.
parameters
())):
model
=
model
.
to
(
dtype
=
self
.
dtype
)
if
(
getattr
(
model
,
"quantization_method"
,
None
)
!=
"bitsandbytes"
if
(
getattr
(
model
,
"quantization_method"
,
None
)
!=
"bitsandbytes"
and
len
({
p
.
device
and
len
({
p
.
device
for
p
in
model
.
parameters
()})
<
2
):
for
p
in
model
.
parameters
()})
<
2
):
model
=
model
.
to
(
self
.
device
)
model
=
model
.
to
(
device
=
self
.
device
)
self
.
model
=
model
self
.
model
=
model
...
@@ -729,7 +730,7 @@ def hf_runner():
...
@@ -729,7 +730,7 @@ def hf_runner():
class
VllmRunner
:
class
VllmRunner
:
"""
"""
The default value of some arguments have been modified from
The default value of some arguments have been modified from
:
class
:
`~vllm.LLM` as follows:
{
class
}
`~vllm.LLM` as follows:
- `trust_remote_code`: Set to `True` instead of `False` for convenience.
- `trust_remote_code`: Set to `True` instead of `False` for convenience.
- `seed`: Set to `0` instead of `None` for test reproducibility.
- `seed`: Set to `0` instead of `None` for test reproducibility.
...
@@ -737,7 +738,7 @@ class VllmRunner:
...
@@ -737,7 +738,7 @@ class VllmRunner:
- `block_size`: Set to `16` instead of `None` to reduce memory usage.
- `block_size`: Set to `16` instead of `None` to reduce memory usage.
- `enable_chunked_prefill`: Set to `False` instead of `None` for
- `enable_chunked_prefill`: Set to `False` instead of `None` for
test reproducibility.
test reproducibility.
- `enforce_eager`: Set to `False`
instead of `None`
to test CUDA graph.
- `enforce_eager`: Set to `False` to test CUDA graph.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -778,7 +779,7 @@ class VllmRunner:
...
@@ -778,7 +779,7 @@ class VllmRunner:
def
get_inputs
(
def
get_inputs
(
self
,
self
,
prompts
:
list
[
str
],
prompts
:
Union
[
list
[
str
],
list
[
torch
.
Tensor
]],
images
:
Optional
[
PromptImageInput
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
...
@@ -800,16 +801,18 @@ class VllmRunner:
...
@@ -800,16 +801,18 @@ class VllmRunner:
if
audios
is
not
None
and
(
audio
:
=
audios
[
i
])
is
not
None
:
if
audios
is
not
None
and
(
audio
:
=
audios
[
i
])
is
not
None
:
multi_modal_data
[
"audio"
]
=
audio
multi_modal_data
[
"audio"
]
=
audio
inputs
.
append
(
text_prompt_kwargs
=
{
TextPrompt
(
prompt
=
prompt
,
(
"prompt"
if
isinstance
(
prompt
,
str
)
else
"prompt_embeds"
):
multi_modal_data
=
multi_modal_data
prompt
,
if
multi_modal_data
else
None
))
"multi_modal_data"
:
multi_modal_data
or
None
}
inputs
.
append
(
TextPrompt
(
**
text_prompt_kwargs
))
return
inputs
return
inputs
def
generate
(
def
generate
(
self
,
self
,
prompts
:
list
[
str
],
prompts
:
Union
[
list
[
str
],
list
[
torch
.
Tensor
]],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
images
:
Optional
[
PromptImageInput
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
...
@@ -835,7 +838,7 @@ class VllmRunner:
...
@@ -835,7 +838,7 @@ class VllmRunner:
output_str
=
sample
.
text
output_str
=
sample
.
text
output_ids
=
list
(
sample
.
token_ids
)
output_ids
=
list
(
sample
.
token_ids
)
req_sample_output_ids
.
append
(
prompt_ids
+
output_ids
)
req_sample_output_ids
.
append
(
prompt_ids
+
output_ids
)
req_sample_output_strs
.
append
(
prompt_str
+
output_str
)
req_sample_output_strs
.
append
(
(
prompt_str
or
""
)
+
output_str
)
outputs
.
append
((
req_sample_output_ids
,
req_sample_output_strs
))
outputs
.
append
((
req_sample_output_ids
,
req_sample_output_strs
))
return
outputs
return
outputs
...
@@ -902,7 +905,7 @@ class VllmRunner:
...
@@ -902,7 +905,7 @@ class VllmRunner:
def
generate_greedy
(
def
generate_greedy
(
self
,
self
,
prompts
:
list
[
str
],
prompts
:
Union
[
list
[
str
],
list
[
torch
.
Tensor
]],
max_tokens
:
int
,
max_tokens
:
int
,
images
:
Optional
[
PromptImageInput
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
...
...
tests/core/test_scheduler.py
View file @
7a985548
...
@@ -2,16 +2,18 @@
...
@@ -2,16 +2,18 @@
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
typing
import
Optional
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
# noqa
import
pytest
# noqa
import
torch
from
torch
import
Use
# noqa
from
torch
import
Use
# noqa
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.scheduler
import
Scheduler
,
SchedulingBudget
from
vllm.core.scheduler
import
Scheduler
,
SchedulingBudget
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SequenceGroup
from
vllm.sequence
import
SequenceGroup
,
SequenceStatus
from
.utils
import
(
append_new_token
,
append_new_token_seq
,
from
.utils
import
(
append_new_token
,
append_new_token_seq
,
append_new_token_seq_group
,
create_dummy_prompt
,
append_new_token_seq_group
,
create_dummy_prompt
,
...
@@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
...
@@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
),
"A partial prefix of C (4 tokens) should be prefilled, with the "
),
"A partial prefix of C (4 tokens) should be prefilled, with the "
"remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
"remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
"then be rounded down to 2 tokens on block size, thus 6 tokens in total."
"then be rounded down to 2 tokens on block size, thus 6 tokens in total."
def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
    """
    Test that the scheduler does not schedule batches with prompt tokens and
    prompt embeddings co-mingled.

    Eleven sequences are queued, alternating between token-id prompts and
    embedding prompts. Each scheduling step must return a non-empty batch
    whose groups all share the same prompt kind, and must schedule as many
    groups of that kind as ``max_num_seqs`` allows.
    """
    block_size = 2
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
    )

    # the even indexed inputs are passed in via embeddings,
    # odds via token_ids (``i % 2`` is truthy for odd ``i``)
    seq_length = 7
    embedding_size = 5
    num_seqs = 11
    seq_tokens: list[list[int]] = []
    seq_embeds: list[Optional[torch.Tensor]] = []
    for i in range(num_seqs):
        if i % 2:
            seq_tokens.append(list(range(seq_length)))
            seq_embeds.append(None)
        else:
            # Placeholder token ids accompany the embedding prompt.
            seq_tokens.append([0] * seq_length)
            seq_embeds.append(torch.rand(embedding_size))

    seq_and_seq_groups = [
        create_dummy_prompt(f"{i}",
                            prompt_tokens=tokens,
                            prompt_embeds=embeds,
                            block_size=block_size)
        for i, (tokens, embeds) in enumerate(zip(seq_tokens, seq_embeds))
    ]

    for _, seq_group in seq_and_seq_groups:
        scheduler.add_seq_group(seq_group)

    while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
        # Snapshot what is still pending before this scheduling step.
        unfinished_seq_groups = [
            seq_group for _, seq_group in seq_and_seq_groups
            if not seq_group.is_finished()
        ]
        _, out = schedule_and_update_computed_tokens(scheduler)
        assert len(out.scheduled_seq_groups) > 0
        # The first scheduled group determines the kind of the whole batch.
        batch_is_prompt_embeds = out.scheduled_seq_groups[
            0].seq_group.uses_prompt_embeds()
        expected_scheduled_seq_groups = [
            seq_group for seq_group in unfinished_seq_groups
            if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
        ]

        # We should have as many scheduled groups as possible, without mixing
        assert len(out.scheduled_seq_groups) == min(
            max_seq_group, len(expected_scheduled_seq_groups))
        assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
                   batch_is_prompt_embeds
                   for scheduled_seq_group in out.scheduled_seq_groups)

        # Finish the scheduled groups
        for scheduled_seq_group in out.scheduled_seq_groups:
            for seq in scheduled_seq_group.seq_group.seqs:
                seq.status = SequenceStatus.FINISHED_STOPPED
        scheduler.free_finished_seq_groups()
tests/core/utils.py
View file @
7a985548
...
@@ -5,9 +5,11 @@ from collections import defaultdict
...
@@ -5,9 +5,11 @@ from collections import defaultdict
from
collections.abc
import
Sequence
as
GenericSequence
from
collections.abc
import
Sequence
as
GenericSequence
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
torch
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.inputs
import
EncoderDecoderInputs
,
token_inputs
from
vllm.inputs
import
EncoderDecoderInputs
,
embeds_inputs
,
token_inputs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
(
Logprob
,
Sequence
,
SequenceGroup
,
from
vllm.sequence
import
(
Logprob
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
...
@@ -19,6 +21,7 @@ def create_dummy_prompt(
...
@@ -19,6 +21,7 @@ def create_dummy_prompt(
block_size
:
Optional
[
int
]
=
None
,
block_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_tokens
:
Optional
[
list
[
int
]]
=
None
,
prompt_tokens
:
Optional
[
list
[
int
]]
=
None
,
prompt_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
min_tokens
:
int
=
0
,
min_tokens
:
int
=
0
,
max_tokens
:
int
=
16
,
max_tokens
:
int
=
16
,
)
->
tuple
[
Sequence
,
SequenceGroup
]:
)
->
tuple
[
Sequence
,
SequenceGroup
]:
...
@@ -31,9 +34,13 @@ def create_dummy_prompt(
...
@@ -31,9 +34,13 @@ def create_dummy_prompt(
prompt_tokens
=
list
(
range
(
prompt_length
))
prompt_tokens
=
list
(
range
(
prompt_length
))
prompt_str
=
" "
.
join
([
str
(
t
)
for
t
in
prompt_tokens
])
prompt_str
=
" "
.
join
([
str
(
t
)
for
t
in
prompt_tokens
])
inputs
=
token_inputs
(
prompt_token_ids
=
prompt_tokens
,
prompt
=
prompt_str
)
if
prompt_embeds
is
None
else
embeds_inputs
(
prompt_embeds
=
prompt_embeds
)
prompt
=
Sequence
(
prompt
=
Sequence
(
int
(
request_id
),
int
(
request_id
),
inputs
=
token_inputs
(
prompt_tokens
,
prompt
=
prompt_str
)
,
inputs
=
inputs
,
block_size
=
block_size
,
block_size
=
block_size
,
)
)
seq_group
=
SequenceGroup
(
seq_group
=
SequenceGroup
(
...
...
tests/distributed/conftest.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
import
random
from
typing
import
Optional
,
Union
import
msgspec
import
msgspec.msgpack
import
pytest
import
zmq
from
vllm.config
import
KVEventsConfig
from
vllm.distributed.kv_events
import
EventPublisherFactory
from
.test_events
import
SampleBatch
@pytest.fixture
def random_port():
    """Return a random integer in [10000, 60000] to use as a test port."""
    lowest, highest = 10000, 60000
    return random.randint(lowest, highest)
@pytest.fixture
def publisher_config(random_port, request):
    """Build a KVEventsConfig for the "zmq" publisher.

    The transport comes from ``request.param`` when the fixture is
    parametrized and defaults to "inproc". Any other value selects wildcard
    TCP endpoints on ``random_port`` (events) and ``random_port + 1``
    (replay).
    """
    transport = getattr(request, "param", "inproc")
    if transport != "inproc":
        endpoint = f"tcp://*:{random_port}"
        replay_endpoint = f"tcp://*:{random_port + 1}"
    else:
        endpoint = f"inproc://test-{random_port}"
        replay_endpoint = endpoint + "-replay"

    return KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",
        endpoint=endpoint,
        replay_endpoint=replay_endpoint,
        buffer_steps=100,
        hwm=1000,
        topic="test",
    )
@pytest.fixture
def publisher(publisher_config):
    """Yield a publisher built from ``publisher_config``; shut it down after."""
    instance = EventPublisherFactory.create(publisher_config)
    yield instance
    # Teardown: runs once the test using the fixture has finished.
    instance.shutdown()
@pytest.fixture
def subscriber(publisher_config):
    """Yield a MockSubscriber connected to the configured endpoints.

    Wildcard TCP addresses ("tcp://*...") from the config are rewritten to
    127.0.0.1 before connecting. The subscriber is closed on teardown.
    """
    wildcard_prefix = "tcp://*"
    loopback = "127.0.0.1"

    pub_addr = publisher_config.endpoint
    if pub_addr.startswith(wildcard_prefix):
        pub_addr = pub_addr.replace("*", loopback)

    # The replay endpoint may be unset, so check truthiness first.
    replay_addr = publisher_config.replay_endpoint
    if replay_addr and replay_addr.startswith(wildcard_prefix):
        replay_addr = replay_addr.replace("*", loopback)

    mock = MockSubscriber(pub_addr, replay_addr, publisher_config.topic)
    yield mock
    mock.close()
class MockSubscriber:
    """Test helper that receives published event batches over ZMQ.

    Holds a SUB socket for the live stream and, when a replay endpoint is
    given, a REQ socket used to request and read replayed batches.
    """

    def __init__(self,
                 pub_endpoint: str,
                 replay_endpoint: Optional[str] = None,
                 topic: str = "",
                 decode_type=SampleBatch):
        self.ctx = zmq.Context.instance()

        # Subscriber socket: filter on the topic, then connect.
        self.sub = self.ctx.socket(zmq.SUB)
        self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
        self.sub.connect(pub_endpoint)

        # Replay socket is optional; it stays None without an endpoint.
        self.replay = None
        if replay_endpoint:
            self.replay = self.ctx.socket(zmq.REQ)
            self.replay.connect(replay_endpoint)

        self.topic = topic
        self.topic_bytes = topic.encode('utf-8')
        # Every (seq, batch) pair seen by receive_one, in arrival order.
        self.received_msgs: list[tuple[int, SampleBatch]] = []
        self.last_seq = -1
        self.decoder = msgspec.msgpack.Decoder(type=decode_type)

    def receive_one(self,
                    timeout=1000) -> Union[tuple[int, SampleBatch], None]:
        """Receive one message, or return None if the poll times out.

        ``timeout`` is forwarded to the socket's ``poll`` (milliseconds in
        pyzmq). A received message is recorded in ``received_msgs`` and
        updates ``last_seq``.
        """
        if not self.sub.poll(timeout):
            return None

        frames = self.sub.recv_multipart()
        topic_bytes, seq_bytes, payload = frames
        assert topic_bytes == self.topic_bytes

        seq = int.from_bytes(seq_bytes, "big")
        data = self.decoder.decode(payload)
        self.last_seq = seq
        self.received_msgs.append((seq, data))
        return seq, data

    def request_replay(self, start_seq: int) -> None:
        """Send a replay request for messages starting at ``start_seq``.

        Raises:
            ValueError: if no replay socket was configured.
        """
        if not self.replay:
            raise ValueError("Replay socket not initialized")

        # The sequence number goes out as an 8-byte big-endian integer.
        self.replay.send(start_seq.to_bytes(8, "big"))

    def receive_replay(self) -> list[tuple[int, SampleBatch]]:
        """Collect replayed (seq, batch) pairs until the stream ends.

        Reading stops on a poll timeout, on an empty frame list or an empty
        final frame (end-of-replay marker), or on a ZMQError.

        Raises:
            ValueError: if no replay socket was configured.
        """
        if not self.replay:
            raise ValueError("Replay socket not initialized")

        replayed: list[tuple[int, SampleBatch]] = []
        try:
            while self.replay.poll(1000):
                frames = self.replay.recv_multipart()
                if not frames or not frames[-1]:
                    # End of replay marker
                    break

                seq_bytes, payload = frames
                seq = int.from_bytes(seq_bytes, "big")
                replayed.append((seq, self.decoder.decode(payload)))
        except zmq.ZMQError:
            # A socket error ends the replay; return what was collected.
            pass

        return replayed

    def close(self):
        """Close the subscriber socket and, if present, the replay socket."""
        self.sub.close()
        if self.replay:
            self.replay.close()
tests/distributed/test_events.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
import
threading
import
time
import
msgspec
import
pytest
from
vllm.distributed.kv_events
import
(
EventBatch
,
EventPublisherFactory
,
NullEventPublisher
)
class EventSample(
        msgspec.Struct,
        tag=True,  # type: ignore
        array_like=True,  # type: ignore
):
    """Sample event type used by the publisher tests."""
    # Both fields are filled in by create_test_events.
    id: int
    value: str
class SampleBatch(EventBatch):
    """EventBatch subclass carrying EventSample items for publisher tests."""
    # Payload of the batch; other fields come from EventBatch.
    events: list[EventSample]
def create_test_events(count: int) -> SampleBatch:
    """Build a SampleBatch of ``count`` events stamped with the current time.

    Event ``i`` has ``id == i`` and ``value == "test-<i>"``.
    """
    samples = []
    for idx in range(count):
        samples.append(EventSample(id=idx, value=f"test-{idx}"))
    return SampleBatch(ts=time.time(), events=samples)
def test_basic_publishing(publisher, subscriber):
    """Publish one batch and verify the subscriber receives it intact."""
    sent = create_test_events(5)
    publisher.publish(sent)

    outcome = subscriber.receive_one(timeout=1000)
    assert outcome is not None, "No message received"

    seq, received = outcome
    # First batch from a fresh publisher carries sequence number 0.
    assert seq == 0, "Sequence number mismatch"
    assert received.ts == pytest.approx(sent.ts,
                                        abs=0.1), ("Timestamp mismatch")
    assert len(received.events) == len(
        sent.events), ("Number of events mismatch")

    for position, event in enumerate(received.events):
        assert event.id == position, "Event id mismatch"
        assert event.value == f"test-{position}", "Event value mismatch"
def test_multiple_events(publisher, subscriber):
    """Publish ten batches and verify all arrive with sequence numbers 0-9."""
    for _ in range(10):
        publisher.publish(create_test_events(2))

    # Ten receive attempts; attempts that time out are dropped.
    attempts = (subscriber.receive_one(timeout=100) for _ in range(10))
    received = [message for message in attempts if message]

    assert len(received) == 10, "Number of messages mismatch"
    seqs = [seq for seq, _ in received]
    assert seqs == list(range(10)), "Sequence numbers mismatch"
def test_replay_mechanism(publisher, subscriber):
    """Test the replay mechanism works correctly

    Publishes 19 batches, requests a replay from sequence 10, publishes a
    20th batch, and checks the replayed sequence numbers start at or after
    10 and are consecutive.
    """
    for _ in range(19):
        batch = create_test_events(1)
        publisher.publish(batch)

    time.sleep(0.5)  # Need publisher to process above requests
    subscriber.request_replay(10)

    batch = create_test_events(1)
    publisher.publish(batch)  # 20th message

    replayed = subscriber.receive_replay()

    assert len(replayed) > 0, "No replayed messages received"
    seqs = [seq for seq, _ in replayed]
    # This checks the replay honours the requested start, not the ordering;
    # ordering is covered by the consecutive check below.
    assert all(seq >= 10 for seq in seqs), (
        "Replayed messages should start at or after the requested sequence")
    assert seqs == list(range(min(seqs),
                              max(seqs) +
                              1)), ("Replayed messages not consecutive")
def test_buffer_limit(publisher, subscriber, publisher_config):
    """Test buffer limit behavior

    Publishes ``buffer_steps + 10`` batches, requests a replay from sequence
    0, and checks that no more than ``buffer_steps`` batches come back and
    that the oldest replayed sequence number is at least 10.
    """
    buffer_size = publisher_config.buffer_steps

    # Publish more events than the buffer can hold
    for _ in range(buffer_size + 10):
        batch = create_test_events(1)
        publisher.publish(batch)

    time.sleep(0.5)  # Need publisher to process above requests
    subscriber.request_replay(0)

    batch = create_test_events(1)
    publisher.publish(batch)

    replayed = subscriber.receive_replay()

    # Fail with a clear message rather than a ValueError from min() below.
    assert replayed, "No replayed messages received"
    assert len(replayed) <= buffer_size, "Can't replay more than buffer size"

    oldest_seq = min(seq for seq, _ in replayed)
    assert oldest_seq >= 10, "The oldest sequence should be at least 10"
def test_topic_filtering(publisher_config):
    """
    Test that a subscriber only receives messages matching its topic filter
    """
    # Replay is not exercised here; clear it before copying the config.
    publisher_config.replay_endpoint = None

    cfg = publisher_config.model_copy()
    cfg.topic = "foo"
    pub = EventPublisherFactory.create(cfg)

    from .conftest import MockSubscriber
    matching_sub = MockSubscriber(cfg.endpoint, None, "foo")
    other_sub = MockSubscriber(cfg.endpoint, None, "bar")

    try:
        # Short pause before publishing; presumably lets the subscriptions
        # settle -- TODO confirm.
        time.sleep(0.1)

        for _ in range(3):
            pub.publish(create_test_events(1))

        foo_received = []
        for _ in range(3):
            foo_received.append(matching_sub.receive_one(timeout=200))
        assert all(msg is not None for msg in foo_received), (
            "Subscriber with matching topic should receive messages")

        bar_received = []
        for _ in range(3):
            bar_received.append(other_sub.receive_one(timeout=200))
        assert all(msg is None for msg in bar_received), (
            "Subscriber with non-matching topic should receive no messages")
    finally:
        pub.shutdown()
        matching_sub.close()
        other_sub.close()
def test_high_volume(publisher, subscriber):
    """Test publishing and receiving a high volume of events

    A background thread publishes 10,000 batches of 100 events while the
    test thread receives for up to 10 seconds. At least 90% of the batches
    must arrive, with sequence numbers in non-decreasing order.
    """
    num_batches = 10_000
    events_per_batch = 100
    receive_window_s = 10

    # Publish events in a separate thread to not block
    def publish_events():
        for i in range(num_batches):
            batch = create_test_events(events_per_batch)
            publisher.publish(batch)
            # Small delay to avoid overwhelming
            if i % 100 == 0:
                time.sleep(0.01)

    received: list[tuple[int, SampleBatch]] = []

    publisher_thread = threading.Thread(target=publish_events)
    publisher_thread.start()

    try:
        # A monotonic clock keeps the timeout immune to wall-clock changes.
        start_time = time.monotonic()
        while len(received) < num_batches:
            if time.monotonic() - start_time > receive_window_s:
                # Timeout after 10 seconds
                break

            result = subscriber.receive_one(timeout=100)
            if result:
                received.append(result)
    finally:
        # Join even when receiving fails, so the thread is not still
        # publishing while the fixtures are torn down.
        publisher_thread.join()

    assert len(received) >= num_batches * 0.9, (
        "We should have received most messages")

    seqs = [seq for seq, _ in received]
    assert sorted(seqs) == seqs, "Sequence numbers should be in order"
def test_null_publisher():
    """Check NullEventPublisher accepts publish() and shutdown() calls."""
    null_pub = NullEventPublisher()

    # This should not raise any errors
    null_pub.publish(create_test_events(5))
    null_pub.shutdown()
tests/distributed/test_pipeline_parallel.py
View file @
7a985548
...
@@ -100,9 +100,8 @@ class PPTestSettings:
...
@@ -100,9 +100,8 @@ class PPTestSettings:
eager_mode
=
True
,
eager_mode
=
True
,
chunked_prefill
=
False
),
chunked_prefill
=
False
),
],
],
# only ray is supported for V1
distributed_backends
=
[
"mp"
,
"mp"
,
"ray"
,
"ray"
],
distributed_backends
=
[
"mp"
,
"ray"
,
"ray"
],
vllm_major_versions
=
[
"0"
,
"1"
,
"0"
,
"1"
],
vllm_major_versions
=
[
"0"
,
"0"
,
"1"
],
task
=
task
,
task
=
task
,
test_options
=
PPTestOptions
(
multi_node_only
=
multi_node_only
,
test_options
=
PPTestOptions
(
multi_node_only
=
multi_node_only
,
load_format
=
load_format
),
load_format
=
load_format
),
...
@@ -186,7 +185,7 @@ TEXT_GENERATION_MODELS = {
...
@@ -186,7 +185,7 @@ TEXT_GENERATION_MODELS = {
"mosaicml/mpt-7b"
:
PPTestSettings
.
fast
(),
"mosaicml/mpt-7b"
:
PPTestSettings
.
fast
(),
"nvidia/Minitron-8B-Base"
:
PPTestSettings
.
fast
(),
"nvidia/Minitron-8B-Base"
:
PPTestSettings
.
fast
(),
"allenai/OLMo-1B-hf"
:
PPTestSettings
.
fast
(),
"allenai/OLMo-1B-hf"
:
PPTestSettings
.
fast
(),
"
shanearora/OLMo-7B-1124-hf
"
:
PPTestSettings
.
fast
(),
"
allenai/OLMo-2-0425-1B
"
:
PPTestSettings
.
fast
(),
"allenai/OLMoE-1B-7B-0924-Instruct"
:
PPTestSettings
.
fast
(),
"allenai/OLMoE-1B-7B-0924-Instruct"
:
PPTestSettings
.
fast
(),
"facebook/opt-iml-max-1.3b"
:
PPTestSettings
.
fast
(),
"facebook/opt-iml-max-1.3b"
:
PPTestSettings
.
fast
(),
"OrionStarAI/Orion-14B-Chat"
:
PPTestSettings
.
fast
(),
"OrionStarAI/Orion-14B-Chat"
:
PPTestSettings
.
fast
(),
...
@@ -350,6 +349,11 @@ def _compare_tp(
...
@@ -350,6 +349,11 @@ def _compare_tp(
# Temporary. Currently when zeromq + SPMD is used, it does not properly
# Temporary. Currently when zeromq + SPMD is used, it does not properly
# terminate because of a Ray Compiled Graph issue.
# terminate because of a Ray Compiled Graph issue.
common_args
.
append
(
"--disable-frontend-multiprocessing"
)
common_args
.
append
(
"--disable-frontend-multiprocessing"
)
elif
distributed_backend
==
"mp"
:
# Both V0/V1 of multiprocessing executor support PP
pp_env
=
{
"VLLM_USE_V1"
:
vllm_major_version
,
}
else
:
else
:
pp_env
=
None
pp_env
=
None
...
...
tests/distributed/test_sequence_parallel.py
View file @
7a985548
...
@@ -26,6 +26,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
...
@@ -26,6 +26,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class
ParallelSetup
(
NamedTuple
):
class
ParallelSetup
(
NamedTuple
):
tp_size
:
int
tp_size
:
int
pp_size
:
int
sp_enabled
:
bool
sp_enabled
:
bool
eager_mode
:
bool
eager_mode
:
bool
chunked_prefill
:
bool
chunked_prefill
:
bool
...
@@ -60,6 +61,7 @@ class SPTestSettings:
...
@@ -60,6 +61,7 @@ class SPTestSettings:
def
detailed
(
def
detailed
(
*
,
*
,
tp_base
:
int
=
2
,
tp_base
:
int
=
2
,
pp_base
:
int
=
1
,
multi_node_only
:
bool
=
False
,
multi_node_only
:
bool
=
False
,
task
:
TaskOption
=
"auto"
,
task
:
TaskOption
=
"auto"
,
load_format
:
Optional
[
str
]
=
None
,
load_format
:
Optional
[
str
]
=
None
,
...
@@ -67,18 +69,42 @@ class SPTestSettings:
...
@@ -67,18 +69,42 @@ class SPTestSettings:
return
SPTestSettings
(
return
SPTestSettings
(
parallel_setups
=
[
parallel_setups
=
[
ParallelSetup
(
tp_size
=
tp_base
,
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_base
,
sp_enabled
=
True
,
sp_enabled
=
True
,
eager_mode
=
False
,
eager_mode
=
False
,
chunked_prefill
=
False
),
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
tp_base
,
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_base
,
sp_enabled
=
True
,
sp_enabled
=
True
,
eager_mode
=
False
,
eager_mode
=
False
,
chunked_prefill
=
True
),
chunked_prefill
=
True
),
ParallelSetup
(
tp_size
=
tp_base
,
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_base
,
sp_enabled
=
True
,
sp_enabled
=
True
,
eager_mode
=
True
,
eager_mode
=
True
,
chunked_prefill
=
False
),
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
tp_base
,
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_base
,
sp_enabled
=
True
,
eager_mode
=
True
,
chunked_prefill
=
True
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
2
*
pp_base
,
sp_enabled
=
True
,
eager_mode
=
False
,
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
2
*
pp_base
,
sp_enabled
=
True
,
eager_mode
=
False
,
chunked_prefill
=
True
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
2
*
pp_base
,
sp_enabled
=
True
,
eager_mode
=
True
,
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
2
*
pp_base
,
sp_enabled
=
True
,
sp_enabled
=
True
,
eager_mode
=
True
,
eager_mode
=
True
,
chunked_prefill
=
True
)
chunked_prefill
=
True
)
...
@@ -94,6 +120,7 @@ class SPTestSettings:
...
@@ -94,6 +120,7 @@ class SPTestSettings:
def
fast
(
def
fast
(
*
,
*
,
tp_base
:
int
=
2
,
tp_base
:
int
=
2
,
pp_base
:
int
=
1
,
task
:
TaskOption
=
"auto"
,
task
:
TaskOption
=
"auto"
,
multi_node_only
:
bool
=
False
,
multi_node_only
:
bool
=
False
,
load_format
:
Optional
[
str
]
=
None
,
load_format
:
Optional
[
str
]
=
None
,
...
@@ -101,6 +128,12 @@ class SPTestSettings:
...
@@ -101,6 +128,12 @@ class SPTestSettings:
return
SPTestSettings
(
return
SPTestSettings
(
parallel_setups
=
[
parallel_setups
=
[
ParallelSetup
(
tp_size
=
tp_base
,
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
pp_base
,
sp_enabled
=
True
,
eager_mode
=
False
,
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
tp_base
,
pp_size
=
2
*
pp_base
,
sp_enabled
=
True
,
sp_enabled
=
True
,
eager_mode
=
False
,
eager_mode
=
False
,
chunked_prefill
=
False
),
chunked_prefill
=
False
),
...
@@ -136,6 +169,7 @@ def _compare_sp(
...
@@ -136,6 +169,7 @@ def _compare_sp(
):
):
(
(
tp_size
,
tp_size
,
pp_size
,
sp_enabled
,
sp_enabled
,
eager_mode
,
eager_mode
,
chunked_prefill
,
chunked_prefill
,
...
@@ -167,7 +201,6 @@ def _compare_sp(
...
@@ -167,7 +201,6 @@ def _compare_sp(
else
:
else
:
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
pp_size
=
1
if
num_gpus_available
<
tp_size
*
pp_size
:
if
num_gpus_available
<
tp_size
*
pp_size
:
pytest
.
skip
(
f
"Need at least
{
tp_size
}
x
{
pp_size
}
GPUs"
)
pytest
.
skip
(
f
"Need at least
{
tp_size
}
x
{
pp_size
}
GPUs"
)
if
VLLM_MULTI_NODE
and
distributed_backend
==
"mp"
:
if
VLLM_MULTI_NODE
and
distributed_backend
==
"mp"
:
...
@@ -206,7 +239,7 @@ def _compare_sp(
...
@@ -206,7 +239,7 @@ def _compare_sp(
'compile_sizes'
:
[
4
,
8
],
'compile_sizes'
:
[
4
,
8
],
'splitting_ops'
:
[],
'splitting_ops'
:
[],
'pass_config'
:
{
'pass_config'
:
{
'enable_sequence_parallism'
:
sp_enabled
,
'enable_sequence_parall
el
ism'
:
sp_enabled
,
'enable_noop'
:
True
,
'enable_noop'
:
True
,
'enable_fusion'
:
True
,
'enable_fusion'
:
True
,
},
},
...
@@ -223,7 +256,7 @@ def _compare_sp(
...
@@ -223,7 +256,7 @@ def _compare_sp(
"--distributed-executor-backend"
,
"--distributed-executor-backend"
,
distributed_backend
,
distributed_backend
,
"--compilation_config"
,
"--compilation_config"
,
str
(
compilation_config
),
json
.
dumps
(
compilation_config
),
]
]
tp_env
=
{
tp_env
=
{
...
@@ -256,7 +289,7 @@ def _compare_sp(
...
@@ -256,7 +289,7 @@ def _compare_sp(
SP_TEXT_GENERATION_MODELS
=
{
SP_TEXT_GENERATION_MODELS
=
{
# [Decoder-only]
# [Decoder-only]
"meta-llama/Llama-3.2-1B-Instruct"
:
SPTestSettings
.
detailed
(),
"meta-llama/Llama-3.2-1B-Instruct"
:
SPTestSettings
.
fast
(),
}
}
SP_TEST_MODELS
=
[
SP_TEST_MODELS
=
[
...
...
tests/distributed/test_torchrun_example.py
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# unit test for `examples/offline_inference/torchrun_example.py`
# unit test for `examples/offline_inference/torchrun_example.py`
import
os
import
random
import
random
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -25,6 +25,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
...
@@ -25,6 +25,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# to test if all ranks agree on the same kv cache configuration.
# to test if all ranks agree on the same kv cache configuration.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
tensor_parallel_size
=
2
,
tensor_parallel_size
=
2
,
pipeline_parallel_size
=
int
(
os
.
getenv
(
"PP_SIZE"
,
1
)),
distributed_executor_backend
=
"external_launcher"
,
distributed_executor_backend
=
"external_launcher"
,
gpu_memory_utilization
=
random
.
uniform
(
0.7
,
0.9
),
gpu_memory_utilization
=
random
.
uniform
(
0.7
,
0.9
),
swap_space
=
random
.
randint
(
1
,
4
),
swap_space
=
random
.
randint
(
1
,
4
),
...
...
tests/engine/test_arg_utils.py
View file @
7a985548
...
@@ -8,20 +8,18 @@ from typing import Literal, Optional
...
@@ -8,20 +8,18 @@ from typing import Literal, Optional
import
pytest
import
pytest
from
vllm.config
import
Pooler
Config
,
config
from
vllm.config
import
Compilation
Config
,
config
from
vllm.engine.arg_utils
import
(
EngineArgs
,
contains_type
,
get_kwargs
,
from
vllm.engine.arg_utils
import
(
EngineArgs
,
contains_type
,
get_kwargs
,
get_type
,
is_not_builtin
,
is_type
,
get_type
,
is_not_builtin
,
is_type
,
nullable_kvs
,
optional_type
)
literal_to_kwargs
,
nullable_kvs
,
optional_type
,
parse_type
)
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
@
pytest
.
mark
.
parametrize
((
"type"
,
"value"
,
"expected"
),
[
@
pytest
.
mark
.
parametrize
((
"type"
,
"value"
,
"expected"
),
[
(
int
,
"42"
,
42
),
(
int
,
"42"
,
42
),
(
int
,
"None"
,
None
),
(
float
,
"3.14"
,
3.14
),
(
float
,
"3.14"
,
3.14
),
(
float
,
"None"
,
None
),
(
str
,
"Hello World!"
,
"Hello World!"
),
(
str
,
"Hello World!"
,
"Hello World!"
),
(
str
,
"None"
,
None
),
(
json
.
loads
,
'{"foo":1,"bar":2}'
,
{
(
json
.
loads
,
'{"foo":1,"bar":2}'
,
{
"foo"
:
1
,
"foo"
:
1
,
"bar"
:
2
"bar"
:
2
...
@@ -30,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser
...
@@ -30,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser
"foo"
:
1
,
"foo"
:
1
,
"bar"
:
2
"bar"
:
2
}),
}),
(
json
.
loads
,
"None"
,
None
),
])
])
def
test_
optional
_type
(
type
,
value
,
expected
):
def
test_
parse
_type
(
type
,
value
,
expected
):
optional
_type_func
=
optional
_type
(
type
)
parse
_type_func
=
parse
_type
(
type
)
context
=
nullcontext
()
context
=
nullcontext
()
if
value
==
"foo=1,bar=2"
:
if
value
==
"foo=1,bar=2"
:
context
=
pytest
.
warns
(
DeprecationWarning
)
context
=
pytest
.
warns
(
DeprecationWarning
)
with
context
:
with
context
:
assert
optional_type_func
(
value
)
==
expected
assert
parse_type_func
(
value
)
==
expected
def
test_optional_type
():
optional_type_func
=
optional_type
(
int
)
assert
optional_type_func
(
"None"
)
is
None
assert
optional_type_func
(
"42"
)
==
42
@
pytest
.
mark
.
parametrize
((
"type_hint"
,
"type"
,
"expected"
),
[
@
pytest
.
mark
.
parametrize
((
"type_hint"
,
"type"
,
"expected"
),
[
...
@@ -71,9 +74,57 @@ def test_get_type(type_hints, type, expected):
...
@@ -71,9 +74,57 @@ def test_get_type(type_hints, type, expected):
assert
get_type
(
type_hints
,
type
)
==
expected
assert
get_type
(
type_hints
,
type
)
==
expected
@
pytest
.
mark
.
parametrize
((
"type_hints"
,
"expected"
),
[
({
Literal
[
1
,
2
]},
{
"type"
:
int
,
"choices"
:
[
1
,
2
]
}),
({
Literal
[
1
,
"a"
]},
Exception
),
])
def
test_literal_to_kwargs
(
type_hints
,
expected
):
context
=
nullcontext
()
if
expected
is
Exception
:
context
=
pytest
.
raises
(
expected
)
with
context
:
assert
literal_to_kwargs
(
type_hints
)
==
expected
@
config
@
dataclass
class
NestedConfig
:
field
:
int
=
1
"""field"""
@
config
@
dataclass
class
FromCliConfig1
:
field
:
int
=
1
"""field"""
@
classmethod
def
from_cli
(
cls
,
cli_value
:
str
):
inst
=
cls
(
**
json
.
loads
(
cli_value
))
inst
.
field
+=
1
return
inst
@
config
@
dataclass
class
FromCliConfig2
:
field
:
int
=
1
"""field"""
@
classmethod
def
from_cli
(
cls
,
cli_value
:
str
):
inst
=
cls
(
**
json
.
loads
(
cli_value
))
inst
.
field
+=
2
return
inst
@
config
@
config
@
dataclass
@
dataclass
class
DummyConfig
Class
:
class
DummyConfig
:
regular_bool
:
bool
=
True
regular_bool
:
bool
=
True
"""Regular bool with default True"""
"""Regular bool with default True"""
optional_bool
:
Optional
[
bool
]
=
None
optional_bool
:
Optional
[
bool
]
=
None
...
@@ -81,23 +132,35 @@ class DummyConfigClass:
...
@@ -81,23 +132,35 @@ class DummyConfigClass:
optional_literal
:
Optional
[
Literal
[
"x"
,
"y"
]]
=
None
optional_literal
:
Optional
[
Literal
[
"x"
,
"y"
]]
=
None
"""Optional literal with default None"""
"""Optional literal with default None"""
tuple_n
:
tuple
[
int
,
...]
=
field
(
default_factory
=
lambda
:
(
1
,
2
,
3
))
tuple_n
:
tuple
[
int
,
...]
=
field
(
default_factory
=
lambda
:
(
1
,
2
,
3
))
"""Tuple with
default (1, 2, 3)
"""
"""Tuple with
variable length
"""
tuple_2
:
tuple
[
int
,
int
]
=
field
(
default_factory
=
lambda
:
(
1
,
2
))
tuple_2
:
tuple
[
int
,
int
]
=
field
(
default_factory
=
lambda
:
(
1
,
2
))
"""Tuple with
default (1, 2)
"""
"""Tuple with
fixed length
"""
list_n
:
list
[
int
]
=
field
(
default_factory
=
lambda
:
[
1
,
2
,
3
])
list_n
:
list
[
int
]
=
field
(
default_factory
=
lambda
:
[
1
,
2
,
3
])
"""List with default [1, 2, 3]"""
"""List with variable length"""
list_literal
:
list
[
Literal
[
1
,
2
]]
=
field
(
default_factory
=
list
)
"""List with literal choices"""
literal_literal
:
Literal
[
Literal
[
1
],
Literal
[
2
]]
=
1
"""Literal of literals with default 1"""
json_tip
:
dict
=
field
(
default_factory
=
dict
)
"""Dict which will be JSON in CLI"""
nested_config
:
NestedConfig
=
field
(
default_factory
=
NestedConfig
)
"""Nested config"""
from_cli_config1
:
FromCliConfig1
=
field
(
default_factory
=
FromCliConfig1
)
"""Config with from_cli method"""
from_cli_config2
:
FromCliConfig2
=
field
(
default_factory
=
FromCliConfig2
)
"""Different config with from_cli method"""
@
pytest
.
mark
.
parametrize
((
"type_hint"
,
"expected"
),
[
@
pytest
.
mark
.
parametrize
((
"type_hint"
,
"expected"
),
[
(
int
,
False
),
(
int
,
False
),
(
DummyConfig
Class
,
True
),
(
DummyConfig
,
True
),
])
])
def
test_is_not_builtin
(
type_hint
,
expected
):
def
test_is_not_builtin
(
type_hint
,
expected
):
assert
is_not_builtin
(
type_hint
)
==
expected
assert
is_not_builtin
(
type_hint
)
==
expected
def
test_get_kwargs
():
def
test_get_kwargs
():
kwargs
=
get_kwargs
(
DummyConfig
Class
)
kwargs
=
get_kwargs
(
DummyConfig
)
print
(
kwargs
)
print
(
kwargs
)
# bools should not have their type set
# bools should not have their type set
...
@@ -111,6 +174,20 @@ def test_get_kwargs():
...
@@ -111,6 +174,20 @@ def test_get_kwargs():
# lists should work
# lists should work
assert
kwargs
[
"list_n"
][
"type"
]
is
int
assert
kwargs
[
"list_n"
][
"type"
]
is
int
assert
kwargs
[
"list_n"
][
"nargs"
]
==
"+"
assert
kwargs
[
"list_n"
][
"nargs"
]
==
"+"
# lists with literals should have the correct choices
assert
kwargs
[
"list_literal"
][
"type"
]
is
int
assert
kwargs
[
"list_literal"
][
"nargs"
]
==
"+"
assert
kwargs
[
"list_literal"
][
"choices"
]
==
[
1
,
2
]
# literals of literals should have merged choices
assert
kwargs
[
"literal_literal"
][
"choices"
]
==
[
1
,
2
]
# dict should have json tip in help
json_tip
=
"Should either be a valid JSON string or JSON keys"
assert
json_tip
in
kwargs
[
"json_tip"
][
"help"
]
# nested config should should construct the nested config
assert
kwargs
[
"nested_config"
][
"type"
](
'{"field": 2}'
)
==
NestedConfig
(
2
)
# from_cli configs should be constructed with the correct method
assert
kwargs
[
"from_cli_config1"
][
"type"
](
'{"field": 2}'
).
field
==
3
assert
kwargs
[
"from_cli_config2"
][
"type"
](
'{"field": 2}'
).
field
==
4
@
pytest
.
mark
.
parametrize
((
"arg"
,
"expected"
),
[
@
pytest
.
mark
.
parametrize
((
"arg"
,
"expected"
),
[
...
@@ -146,7 +223,7 @@ def test_compilation_config():
...
@@ -146,7 +223,7 @@ def test_compilation_config():
# default value
# default value
args
=
parser
.
parse_args
([])
args
=
parser
.
parse_args
([])
assert
args
.
compilation_config
is
None
assert
args
.
compilation_config
==
CompilationConfig
()
# set to O3
# set to O3
args
=
parser
.
parse_args
([
"-O3"
])
args
=
parser
.
parse_args
([
"-O3"
])
...
@@ -163,7 +240,7 @@ def test_compilation_config():
...
@@ -163,7 +240,7 @@ def test_compilation_config():
# set to string form of a dict
# set to string form of a dict
args
=
parser
.
parse_args
([
args
=
parser
.
parse_args
([
"--compilation-config"
,
"--compilation-config"
,
"{'
level
'
: 3,
'
cudagraph_capture_sizes
'
: [1, 2, 4, 8]}
"
,
'{"
level
"
: 3,
"
cudagraph_capture_sizes
"
: [1, 2, 4, 8]}
'
,
])
])
assert
(
args
.
compilation_config
.
level
==
3
and
assert
(
args
.
compilation_config
.
level
==
3
and
args
.
compilation_config
.
cudagraph_capture_sizes
==
[
1
,
2
,
4
,
8
])
args
.
compilation_config
.
cudagraph_capture_sizes
==
[
1
,
2
,
4
,
8
])
...
@@ -171,7 +248,7 @@ def test_compilation_config():
...
@@ -171,7 +248,7 @@ def test_compilation_config():
# set to string form of a dict
# set to string form of a dict
args
=
parser
.
parse_args
([
args
=
parser
.
parse_args
([
"--compilation-config="
"--compilation-config="
"{'
level
'
: 3,
'
cudagraph_capture_sizes
'
: [1, 2, 4, 8]}
"
,
'{"
level
"
: 3,
"
cudagraph_capture_sizes
"
: [1, 2, 4, 8]}
'
,
])
])
assert
(
args
.
compilation_config
.
level
==
3
and
assert
(
args
.
compilation_config
.
level
==
3
and
args
.
compilation_config
.
cudagraph_capture_sizes
==
[
1
,
2
,
4
,
8
])
args
.
compilation_config
.
cudagraph_capture_sizes
==
[
1
,
2
,
4
,
8
])
...
@@ -196,17 +273,6 @@ def test_prefix_cache_default():
...
@@ -196,17 +273,6 @@ def test_prefix_cache_default():
assert
not
engine_args
.
enable_prefix_caching
assert
not
engine_args
.
enable_prefix_caching
def
test_valid_pooling_config
():
parser
=
EngineArgs
.
add_cli_args
(
FlexibleArgumentParser
())
args
=
parser
.
parse_args
([
'--override-pooler-config'
,
'{"pooling_type": "MEAN"}'
,
])
engine_args
=
EngineArgs
.
from_cli_args
(
args
=
args
)
assert
engine_args
.
override_pooler_config
==
PoolerConfig
(
pooling_type
=
"MEAN"
,
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"arg"
),
(
"arg"
),
[
[
...
...
tests/engine/test_
skip_tokenizer_init
.py
→
tests/engine/test_
options
.py
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
contextlib
import
nullcontext
import
pytest
import
pytest
...
@@ -14,6 +15,7 @@ def test_skip_tokenizer_initialization(model: str):
...
@@ -14,6 +15,7 @@ def test_skip_tokenizer_initialization(model: str):
llm
=
LLM
(
llm
=
LLM
(
model
=
model
,
model
=
model
,
skip_tokenizer_init
=
True
,
skip_tokenizer_init
=
True
,
enforce_eager
=
True
,
)
)
sampling_params
=
SamplingParams
(
prompt_logprobs
=
True
,
detokenize
=
True
)
sampling_params
=
SamplingParams
(
prompt_logprobs
=
True
,
detokenize
=
True
)
...
@@ -27,3 +29,32 @@ def test_skip_tokenizer_initialization(model: str):
...
@@ -27,3 +29,32 @@ def test_skip_tokenizer_initialization(model: str):
assert
len
(
completions
)
>
0
assert
len
(
completions
)
>
0
assert
completions
[
0
].
text
==
""
assert
completions
[
0
].
text
==
""
assert
completions
[
0
].
token_ids
assert
completions
[
0
].
token_ids
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"distilbert/distilgpt2"
])
@
pytest
.
mark
.
parametrize
(
"enable_prompt_embeds"
,
[
True
,
False
])
def
test_enable_prompt_embeds
(
hf_runner
,
model
:
str
,
enable_prompt_embeds
:
bool
):
prompt
=
"abc"
with
hf_runner
(
model
)
as
hf_model
:
token_ids
=
hf_model
.
tokenizer
(
prompt
,
return_tensors
=
"pt"
).
input_ids
token_ids
=
token_ids
.
to
(
hf_model
.
model
.
device
)
embed_layer
=
hf_model
.
model
.
get_input_embeddings
()
prompt_embeds
=
embed_layer
(
token_ids
).
squeeze
(
0
)
ctx
=
(
nullcontext
()
if
enable_prompt_embeds
else
pytest
.
raises
(
ValueError
,
match
=
"set `--enable-prompt-embeds`"
))
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm
=
LLM
(
model
=
model
,
enable_prompt_embeds
=
enable_prompt_embeds
,
enforce_eager
=
True
,
)
with
ctx
:
llm
.
generate
({
"prompt_embeds"
:
prompt_embeds
})
tests/entrypoints/llm/test_chat.py
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
weakref
import
pytest
import
pytest
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
..openai.test_vision
import
TEST_IMAGE_URLS
from
..openai.test_vision
import
TEST_IMAGE_URLS
def
test_chat
():
@
pytest
.
fixture
(
scope
=
"function"
)
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
)
def
text_llm
():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
enforce_eager
=
True
,
seed
=
0
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup_dist_env_and_memory
()
def
test_chat
(
text_llm
):
prompt1
=
"Explain the concept of entropy."
prompt1
=
"Explain the concept of entropy."
messages
=
[
messages
=
[
{
{
...
@@ -21,13 +37,11 @@ def test_chat():
...
@@ -21,13 +37,11 @@ def test_chat():
"content"
:
prompt1
"content"
:
prompt1
},
},
]
]
outputs
=
llm
.
chat
(
messages
)
outputs
=
text_
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
1
assert
len
(
outputs
)
==
1
def
test_multi_chat
():
def
test_multi_chat
(
text_llm
):
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
)
prompt1
=
"Explain the concept of entropy."
prompt1
=
"Explain the concept of entropy."
prompt2
=
"Explain what among us is."
prompt2
=
"Explain what among us is."
...
@@ -55,13 +69,14 @@ def test_multi_chat():
...
@@ -55,13 +69,14 @@ def test_multi_chat():
messages
=
[
conversation1
,
conversation2
]
messages
=
[
conversation1
,
conversation2
]
outputs
=
llm
.
chat
(
messages
)
outputs
=
text_
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
2
assert
len
(
outputs
)
==
2
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
@
pytest
.
fixture
(
scope
=
"function"
)
[[
TEST_IMAGE_URLS
[
0
],
TEST_IMAGE_URLS
[
1
]]])
def
vision_llm
():
def
test_chat_multi_image
(
image_urls
:
list
[
str
]):
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
llm
=
LLM
(
model
=
"microsoft/Phi-3.5-vision-instruct"
,
model
=
"microsoft/Phi-3.5-vision-instruct"
,
max_model_len
=
4096
,
max_model_len
=
4096
,
...
@@ -69,8 +84,20 @@ def test_chat_multi_image(image_urls: list[str]):
...
@@ -69,8 +84,20 @@ def test_chat_multi_image(image_urls: list[str]):
enforce_eager
=
True
,
enforce_eager
=
True
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
limit_mm_per_prompt
=
{
"image"
:
2
},
limit_mm_per_prompt
=
{
"image"
:
2
},
seed
=
0
,
)
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
[[
TEST_IMAGE_URLS
[
0
],
TEST_IMAGE_URLS
[
1
]]])
def
test_chat_multi_image
(
vision_llm
,
image_urls
:
list
[
str
]):
messages
=
[{
messages
=
[{
"role"
:
"role"
:
"user"
,
"user"
,
...
@@ -87,16 +114,15 @@ def test_chat_multi_image(image_urls: list[str]):
...
@@ -87,16 +114,15 @@ def test_chat_multi_image(image_urls: list[str]):
},
},
],
],
}]
}]
outputs
=
llm
.
chat
(
messages
)
outputs
=
vision_
llm
.
chat
(
messages
)
assert
len
(
outputs
)
>=
0
assert
len
(
outputs
)
>=
0
def
test_llm_chat_tokenization_no_double_bos
():
def
test_llm_chat_tokenization_no_double_bos
(
text_llm
):
"""
"""
LLM.chat() should not add special tokens when using chat templates.
LLM.chat() should not add special tokens when using chat templates.
Check we get a single BOS token for llama chat.
Check we get a single BOS token for llama chat.
"""
"""
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
enforce_eager
=
True
)
messages
=
[
messages
=
[
{
{
"role"
:
"system"
,
"role"
:
"system"
,
...
@@ -107,13 +133,64 @@ def test_llm_chat_tokenization_no_double_bos():
...
@@ -107,13 +133,64 @@ def test_llm_chat_tokenization_no_double_bos():
"content"
:
"Hello!"
"content"
:
"Hello!"
},
},
]
]
outputs
=
llm
.
chat
(
messages
)
outputs
=
text_
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
1
assert
len
(
outputs
)
==
1
prompt_token_ids
=
getattr
(
outputs
[
0
],
"prompt_token_ids"
,
None
)
prompt_token_ids
=
outputs
[
0
].
prompt_token_ids
assert
prompt_token_ids
is
not
None
assert
prompt_token_ids
is
not
None
bos_token
=
llm
.
get_tokenizer
().
bos_token_id
bos_token
=
text_
llm
.
get_tokenizer
().
bos_token_id
# Ensure we have a single BOS
# Ensure we have a single BOS
assert
prompt_token_ids
[
0
]
==
bos_token
assert
prompt_token_ids
[
0
]
==
bos_token
assert
prompt_token_ids
[
1
]
!=
bos_token
,
"Double BOS"
assert
prompt_token_ids
[
1
]
!=
bos_token
,
"Double BOS"
@
pytest
.
fixture
(
scope
=
"function"
)
def
thinking_llm
():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
model
=
"Qwen/Qwen3-0.6B"
,
max_model_len
=
4096
,
enforce_eager
=
True
,
seed
=
0
,
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
parametrize
(
"enable_thinking"
,
[
True
,
False
])
def
test_chat_extra_kwargs
(
thinking_llm
,
enable_thinking
):
messages
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
"What is 1+1?"
},
]
outputs
=
thinking_llm
.
chat
(
messages
,
chat_template_kwargs
=
{
"enable_thinking"
:
enable_thinking
},
)
assert
len
(
outputs
)
==
1
prompt_token_ids
=
outputs
[
0
].
prompt_token_ids
assert
prompt_token_ids
is
not
None
think_id
=
thinking_llm
.
get_tokenizer
().
get_vocab
()[
"<think>"
]
if
enable_thinking
:
assert
think_id
not
in
prompt_token_ids
else
:
# The chat template includes dummy thinking process
assert
think_id
in
prompt_token_ids
tests/entrypoints/llm/test_collective_rpc.py
View file @
7a985548
...
@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
...
@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"mp"
,
"ray"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"mp"
,
"ray"
])
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
def
test_collective_rpc
(
tp_size
,
backend
):
def
test_collective_rpc
(
tp_size
,
backend
,
monkeypatch
):
if
tp_size
==
1
and
backend
==
"ray"
:
if
tp_size
==
1
and
backend
==
"ray"
:
pytest
.
skip
(
"Skip duplicate test case"
)
pytest
.
skip
(
"Skip duplicate test case"
)
if
tp_size
==
1
:
if
tp_size
==
1
:
...
@@ -21,6 +21,7 @@ def test_collective_rpc(tp_size, backend):
...
@@ -21,6 +21,7 @@ def test_collective_rpc(tp_size, backend):
def
echo_rank
(
self
):
def
echo_rank
(
self
):
return
self
.
rank
return
self
.
rank
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
enforce_eager
=
True
,
enforce_eager
=
True
,
load_format
=
"dummy"
,
load_format
=
"dummy"
,
...
...
tests/entrypoints/llm/test_guided_generate.py
View file @
7a985548
...
@@ -16,10 +16,11 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams
...
@@ -16,10 +16,11 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME
=
"Qwen/Qwen2.5-1.5B-Instruct"
MODEL_NAME
=
"Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS
=
[
GUIDED_DECODING_BACKENDS
=
[
"outlines"
,
# (backend, disable_any_whitespace),
"lm-format-enforcer"
,
(
"outlines"
,
False
),
"xgrammar:disable-any-whitespace"
,
(
"lm-format-enforcer"
,
False
),
"guidance:disable-any-whitespace"
,
(
"xgrammar"
,
True
),
(
"guidance"
,
True
),
]
]
...
@@ -36,13 +37,17 @@ def llm():
...
@@ -36,13 +37,17 @@ def llm():
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend,disable_any_whitespace"
,
def
test_guided_regex
(
sample_regex
,
llm
,
guided_decoding_backend
:
str
):
GUIDED_DECODING_BACKENDS
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
def
test_guided_regex
(
sample_regex
,
llm
,
guided_decoding_backend
:
str
,
top_p
=
0.95
,
disable_any_whitespace
:
bool
):
guided_decoding
=
GuidedDecodingParams
(
sampling_params
=
SamplingParams
(
regex
=
sample_regex
,
temperature
=
0.8
,
backend
=
guided_decoding_backend
))
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
regex
=
sample_regex
,
backend
=
guided_decoding_backend
,
disable_any_whitespace
=
disable_any_whitespace
))
outputs
=
llm
.
generate
(
prompts
=
[
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example IPv4 address with this regex:
{
sample_regex
}
"
f
"Give an example IPv4 address with this regex:
{
sample_regex
}
"
]
*
2
,
]
*
2
,
...
@@ -62,14 +67,18 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
...
@@ -62,14 +67,18 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend,disable_any_whitespace"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_json_completion
(
sample_json_schema
,
llm
,
def
test_guided_json_completion
(
sample_json_schema
,
llm
,
guided_decoding_backend
:
str
):
guided_decoding_backend
:
str
,
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
disable_any_whitespace
:
bool
):
max_tokens
=
1000
,
sampling_params
=
SamplingParams
(
guided_decoding
=
GuidedDecodingParams
(
temperature
=
1.0
,
json
=
sample_json_schema
,
max_tokens
=
1000
,
backend
=
guided_decoding_backend
))
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_json_schema
,
backend
=
guided_decoding_backend
,
disable_any_whitespace
=
disable_any_whitespace
))
outputs
=
llm
.
generate
(
prompts
=
[
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an employee profile "
f
"Give an example JSON for an employee profile "
f
"that fits this schema:
{
sample_json_schema
}
"
f
"that fits this schema:
{
sample_json_schema
}
"
...
@@ -92,14 +101,18 @@ def test_guided_json_completion(sample_json_schema, llm,
...
@@ -92,14 +101,18 @@ def test_guided_json_completion(sample_json_schema, llm,
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend,disable_any_whitespace"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_complex_json_completion
(
sample_complex_json_schema
,
llm
,
def
test_guided_complex_json_completion
(
sample_complex_json_schema
,
llm
,
guided_decoding_backend
:
str
):
guided_decoding_backend
:
str
,
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
disable_any_whitespace
:
bool
):
max_tokens
=
1000
,
sampling_params
=
SamplingParams
(
guided_decoding
=
GuidedDecodingParams
(
temperature
=
1.0
,
json
=
sample_complex_json_schema
,
max_tokens
=
1000
,
backend
=
guided_decoding_backend
))
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_complex_json_schema
,
backend
=
guided_decoding_backend
,
disable_any_whitespace
=
disable_any_whitespace
))
outputs
=
llm
.
generate
(
prompts
=
[
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an assignment grade "
f
"Give an example JSON for an assignment grade "
f
"that fits this schema:
{
sample_complex_json_schema
}
"
f
"that fits this schema:
{
sample_complex_json_schema
}
"
...
@@ -123,14 +136,18 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
...
@@ -123,14 +136,18 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend,disable_any_whitespace"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_definition_json_completion
(
sample_definition_json_schema
,
llm
,
def
test_guided_definition_json_completion
(
sample_definition_json_schema
,
llm
,
guided_decoding_backend
:
str
):
guided_decoding_backend
:
str
,
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
disable_any_whitespace
:
bool
):
max_tokens
=
1000
,
sampling_params
=
SamplingParams
(
guided_decoding
=
GuidedDecodingParams
(
temperature
=
1.0
,
json
=
sample_definition_json_schema
,
max_tokens
=
1000
,
backend
=
guided_decoding_backend
))
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_definition_json_schema
,
backend
=
guided_decoding_backend
,
disable_any_whitespace
=
disable_any_whitespace
))
outputs
=
llm
.
generate
(
prompts
=
[
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for solving 8x + 7 = -23 "
f
"Give an example JSON for solving 8x + 7 = -23 "
f
"that fits this schema:
{
sample_definition_json_schema
}
"
f
"that fits this schema:
{
sample_definition_json_schema
}
"
...
@@ -154,14 +171,18 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
...
@@ -154,14 +171,18 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend,disable_any_whitespace"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_enum_json_completion
(
sample_enum_json_schema
,
llm
,
def
test_guided_enum_json_completion
(
sample_enum_json_schema
,
llm
,
guided_decoding_backend
:
str
):
guided_decoding_backend
:
str
,
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
disable_any_whitespace
:
bool
):
max_tokens
=
1000
,
sampling_params
=
SamplingParams
(
guided_decoding
=
GuidedDecodingParams
(
temperature
=
1.0
,
json
=
sample_enum_json_schema
,
max_tokens
=
1000
,
backend
=
guided_decoding_backend
))
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_enum_json_schema
,
backend
=
guided_decoding_backend
,
disable_any_whitespace
=
disable_any_whitespace
))
outputs
=
llm
.
generate
(
prompts
=
[
outputs
=
llm
.
generate
(
prompts
=
[
"Create a bug report JSON that fits this schema: "
"Create a bug report JSON that fits this schema: "
f
"
{
sample_enum_json_schema
}
. Make it for a high priority critical bug."
f
"
{
sample_enum_json_schema
}
. Make it for a high priority critical bug."
...
@@ -195,14 +216,18 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
...
@@ -195,14 +216,18 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend,disable_any_whitespace"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_choice_completion
(
sample_guided_choice
,
llm
,
def
test_guided_choice_completion
(
sample_guided_choice
,
llm
,
guided_decoding_backend
:
str
):
guided_decoding_backend
:
str
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
disable_any_whitespace
:
bool
):
top_p
=
0.95
,
sampling_params
=
SamplingParams
(
guided_decoding
=
GuidedDecodingParams
(
temperature
=
0.8
,
choice
=
sample_guided_choice
,
top_p
=
0.95
,
backend
=
guided_decoding_backend
))
guided_decoding
=
GuidedDecodingParams
(
choice
=
sample_guided_choice
,
backend
=
guided_decoding_backend
,
disable_any_whitespace
=
disable_any_whitespace
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
"The best language for type-safe systems programming is "
,
prompts
=
"The best language for type-safe systems programming is "
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
...
@@ -221,15 +246,19 @@ def test_guided_choice_completion(sample_guided_choice, llm,
...
@@ -221,15 +246,19 @@ def test_guided_choice_completion(sample_guided_choice, llm,
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend,disable_any_whitespace"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_grammar
(
sample_sql_statements
,
llm
,
def
test_guided_grammar
(
sample_sql_statements
,
llm
,
guided_decoding_backend
:
str
):
guided_decoding_backend
:
str
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
disable_any_whitespace
:
bool
):
top_p
=
0.95
,
sampling_params
=
SamplingParams
(
max_tokens
=
1000
,
temperature
=
0.8
,
guided_decoding
=
GuidedDecodingParams
(
top_p
=
0.95
,
grammar
=
sample_sql_statements
,
max_tokens
=
1000
,
backend
=
guided_decoding_backend
))
guided_decoding
=
GuidedDecodingParams
(
grammar
=
sample_sql_statements
,
backend
=
guided_decoding_backend
,
disable_any_whitespace
=
disable_any_whitespace
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a sql state that select col_1 from "
prompts
=
(
"Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"
),
"table_1 where it is equals to 1"
),
...
@@ -300,7 +329,8 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
...
@@ -300,7 +329,8 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
top_p
=
0.95
,
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
guided_decoding
=
GuidedDecodingParams
(
json
=
unsupported_json
,
json
=
unsupported_json
,
backend
=
"xgrammar:no-fallback"
))
backend
=
"xgrammar"
,
disable_fallback
=
True
))
with
pytest
.
raises
(
with
pytest
.
raises
(
ValueError
,
ValueError
,
...
@@ -312,14 +342,18 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
...
@@ -312,14 +342,18 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend,disable_any_whitespace"
,
def
test_guided_json_object
(
llm
,
guided_decoding_backend
:
str
):
GUIDED_DECODING_BACKENDS
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
def
test_guided_json_object
(
llm
,
guided_decoding_backend
:
str
,
max_tokens
=
100
,
disable_any_whitespace
:
bool
):
n
=
2
,
sampling_params
=
SamplingParams
(
guided_decoding
=
GuidedDecodingParams
(
temperature
=
1.0
,
json_object
=
True
,
max_tokens
=
100
,
backend
=
guided_decoding_backend
))
n
=
2
,
guided_decoding
=
GuidedDecodingParams
(
json_object
=
True
,
backend
=
guided_decoding_backend
,
disable_any_whitespace
=
disable_any_whitespace
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a JSON object with curly braces for a person with "
prompts
=
(
"Generate a JSON object with curly braces for a person with "
...
@@ -337,7 +371,7 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
...
@@ -337,7 +371,7 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
print
(
generated_text
)
print
(
generated_text
)
assert
generated_text
is
not
None
assert
generated_text
is
not
None
if
'
disable
-
any
-
whitespace
'
in
guided_decoding_backend
:
if
disable
_
any
_
whitespace
:
assert
"
\n
"
not
in
generated_text
assert
"
\n
"
not
in
generated_text
# Parse to verify it is valid JSON
# Parse to verify it is valid JSON
...
@@ -359,14 +393,18 @@ class CarDescription(BaseModel):
...
@@ -359,14 +393,18 @@ class CarDescription(BaseModel):
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend,disable_any_whitespace"
,
def
test_guided_json_completion_with_enum
(
llm
,
guided_decoding_backend
:
str
):
GUIDED_DECODING_BACKENDS
)
def
test_guided_json_completion_with_enum
(
llm
,
guided_decoding_backend
:
str
,
disable_any_whitespace
:
bool
):
json_schema
=
CarDescription
.
model_json_schema
()
json_schema
=
CarDescription
.
model_json_schema
()
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
sampling_params
=
SamplingParams
(
max_tokens
=
1000
,
temperature
=
1.0
,
guided_decoding
=
GuidedDecodingParams
(
max_tokens
=
1000
,
json
=
json_schema
,
guided_decoding
=
GuidedDecodingParams
(
backend
=
guided_decoding_backend
))
json
=
json_schema
,
backend
=
guided_decoding_backend
,
disable_any_whitespace
=
disable_any_whitespace
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
"Generate a JSON with the brand, model and car_type of"
prompts
=
"Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
,
"the most iconic car from the 90's"
,
...
@@ -387,9 +425,10 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
...
@@ -387,9 +425,10 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend,disable_any_whitespace"
,
def
test_guided_number_range_json_completion
(
llm
,
GUIDED_DECODING_BACKENDS
)
guided_decoding_backend
:
str
):
def
test_guided_number_range_json_completion
(
llm
,
guided_decoding_backend
:
str
,
disable_any_whitespace
:
bool
):
sample_output_schema
=
{
sample_output_schema
=
{
"type"
:
"object"
,
"type"
:
"object"
,
"properties"
:
{
"properties"
:
{
...
@@ -413,8 +452,10 @@ def test_guided_number_range_json_completion(llm,
...
@@ -413,8 +452,10 @@ def test_guided_number_range_json_completion(llm,
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
temperature
=
1.0
,
max_tokens
=
1000
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_output_schema
,
guided_decoding
=
GuidedDecodingParams
(
backend
=
guided_decoding_backend
),
json
=
sample_output_schema
,
backend
=
guided_decoding_backend
,
disable_any_whitespace
=
disable_any_whitespace
),
)
)
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
[
prompts
=
[
...
@@ -466,8 +507,12 @@ def test_guidance_no_additional_properties(llm):
...
@@ -466,8 +507,12 @@ def test_guidance_no_additional_properties(llm):
"large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
"large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
"<|im_end|>
\n
<|im_start|>assistant
\n
"
)
"<|im_end|>
\n
<|im_start|>assistant
\n
"
)
def
generate_with_backend
(
backend
):
def
generate_with_backend
(
backend
,
disable_additional_properties
):
guided_params
=
GuidedDecodingParams
(
json
=
schema
,
backend
=
backend
)
guided_params
=
GuidedDecodingParams
(
json
=
schema
,
backend
=
backend
,
disable_any_whitespace
=
True
,
disable_additional_properties
=
disable_additional_properties
)
sampling_params
=
SamplingParams
(
temperature
=
0
,
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
256
,
max_tokens
=
256
,
guided_decoding
=
guided_params
)
guided_decoding
=
guided_params
)
...
@@ -481,7 +526,7 @@ def test_guidance_no_additional_properties(llm):
...
@@ -481,7 +526,7 @@ def test_guidance_no_additional_properties(llm):
jsonschema
.
validate
(
instance
=
parsed_json
,
schema
=
schema
)
jsonschema
.
validate
(
instance
=
parsed_json
,
schema
=
schema
)
return
parsed_json
return
parsed_json
base_generated
=
generate_with_backend
(
'
guidance
:disable-any-whitespace'
)
base_generated
=
generate_with_backend
(
"
guidance
"
,
False
)
assert
"a1"
in
base_generated
assert
"a1"
in
base_generated
assert
"a2"
in
base_generated
assert
"a2"
in
base_generated
assert
"a3"
in
base_generated
assert
"a3"
in
base_generated
...
@@ -490,8 +535,7 @@ def test_guidance_no_additional_properties(llm):
...
@@ -490,8 +535,7 @@ def test_guidance_no_additional_properties(llm):
assert
"a5"
in
base_generated
assert
"a5"
in
base_generated
assert
"a6"
in
base_generated
assert
"a6"
in
base_generated
generated
=
generate_with_backend
(
generated
=
generate_with_backend
(
"guidance"
,
True
)
'guidance:no-additional-properties,disable-any-whitespace'
)
assert
"a1"
in
generated
assert
"a1"
in
generated
assert
"a2"
in
generated
assert
"a2"
in
generated
assert
"a3"
in
generated
assert
"a3"
in
generated
...
...
tests/entrypoints/openai/test_audio.py
View file @
7a985548
...
@@ -272,7 +272,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
...
@@ -272,7 +272,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
chat_completion
=
await
client
.
chat
.
completions
.
create
(
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
model
=
model_name
,
messages
=
messages
,
messages
=
messages
,
max_completion_tokens
=
10
,
max_completion_tokens
=
8
,
temperature
=
0.0
,
temperature
=
0.0
,
)
)
output
=
chat_completion
.
choices
[
0
].
message
.
content
output
=
chat_completion
.
choices
[
0
].
message
.
content
...
@@ -282,7 +282,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
...
@@ -282,7 +282,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
stream
=
await
client
.
chat
.
completions
.
create
(
stream
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
model
=
model_name
,
messages
=
messages
,
messages
=
messages
,
max_completion_tokens
=
10
,
max_completion_tokens
=
8
,
temperature
=
0.0
,
temperature
=
0.0
,
stream
=
True
,
stream
=
True
,
)
)
...
@@ -332,7 +332,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
...
@@ -332,7 +332,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
chat_completion
=
await
client
.
chat
.
completions
.
create
(
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
model
=
model_name
,
messages
=
messages
,
messages
=
messages
,
max_completion_tokens
=
10
,
max_completion_tokens
=
8
,
temperature
=
0.0
,
temperature
=
0.0
,
)
)
output
=
chat_completion
.
choices
[
0
].
message
.
content
output
=
chat_completion
.
choices
[
0
].
message
.
content
...
@@ -342,7 +342,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
...
@@ -342,7 +342,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
stream
=
await
client
.
chat
.
completions
.
create
(
stream
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
model
=
model_name
,
messages
=
messages
,
messages
=
messages
,
max_completion_tokens
=
10
,
max_completion_tokens
=
8
,
temperature
=
0.0
,
temperature
=
0.0
,
stream
=
True
,
stream
=
True
,
)
)
...
...
tests/entrypoints/openai/test_chat_template.py
View file @
7a985548
...
@@ -2,11 +2,13 @@
...
@@ -2,11 +2,13 @@
import
pytest
import
pytest
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
(
apply_hf_chat_template
,
from
vllm.entrypoints.chat_utils
import
(
apply_hf_chat_template
,
load_chat_template
)
load_chat_template
)
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
...models.registry
import
HF_EXAMPLE_MODELS
from
...utils
import
VLLM_PATH
from
...utils
import
VLLM_PATH
chatml_jinja_path
=
VLLM_PATH
/
"examples/template_chatml.jinja"
chatml_jinja_path
=
VLLM_PATH
/
"examples/template_chatml.jinja"
...
@@ -91,8 +93,22 @@ def test_no_load_chat_template_literallike():
...
@@ -91,8 +93,22 @@ def test_no_load_chat_template_literallike():
MODEL_TEMPLATE_GENERATON_OUTPUT
)
MODEL_TEMPLATE_GENERATON_OUTPUT
)
def
test_get_gen_prompt
(
model
,
template
,
add_generation_prompt
,
def
test_get_gen_prompt
(
model
,
template
,
add_generation_prompt
,
continue_final_message
,
expected_output
):
continue_final_message
,
expected_output
):
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_config
=
ModelConfig
(
model
,
tokenizer
=
model_info
.
tokenizer
or
model
,
tokenizer_mode
=
model_info
.
tokenizer_mode
,
trust_remote_code
=
model_info
.
trust_remote_code
,
hf_overrides
=
model_info
.
hf_overrides
,
)
# Initialize the tokenizer
# Initialize the tokenizer
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
template_content
=
load_chat_template
(
chat_template
=
template
)
template_content
=
load_chat_template
(
chat_template
=
template
)
# Create a mock request object using keyword arguments
# Create a mock request object using keyword arguments
...
@@ -106,10 +122,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
...
@@ -106,10 +122,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Call the function and get the result
# Call the function and get the result
result
=
apply_hf_chat_template
(
result
=
apply_hf_chat_template
(
tokenizer
,
tokenizer
=
tokenizer
,
trust_remote_code
=
True
,
conversation
=
mock_request
.
messages
,
conversation
=
mock_request
.
messages
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
model_config
=
model_config
,
tools
=
None
,
tools
=
None
,
add_generation_prompt
=
mock_request
.
add_generation_prompt
,
add_generation_prompt
=
mock_request
.
add_generation_prompt
,
continue_final_message
=
mock_request
.
continue_final_message
,
continue_final_message
=
mock_request
.
continue_final_message
,
...
...
tests/entrypoints/openai/test_chat_with_tool_reasoning.py
View file @
7a985548
...
@@ -13,9 +13,9 @@ MODEL_NAME = "Qwen/QwQ-32B"
...
@@ -13,9 +13,9 @@ MODEL_NAME = "Qwen/QwQ-32B"
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
# noqa: F811
def
server
():
# noqa: F811
args
=
[
args
=
[
"--max-model-len"
,
"8192"
,
"--enforce-eager"
,
"--
enable-
reasoning"
,
"--max-model-len"
,
"8192"
,
"--enforce-eager"
,
"--reasoning
-parser
"
,
"--reasoning-parser"
,
"deepseek_r1"
,
"--enable-auto-tool-choice"
,
"deepseek_r1"
,
"--enable-auto-tool-choice"
,
"--tool-call-parser"
,
"--tool-call-parser"
,
"hermes"
"hermes"
]
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
...
...
tests/entrypoints/openai/test_classification.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
requests
from
vllm.entrypoints.openai.protocol
import
ClassificationResponse
from
...utils
import
RemoteOpenAIServer
MODEL_NAME
=
"jason9693/Qwen2.5-1.5B-apeach"
DTYPE
=
"float32"
# Use float32 to avoid NaN issue
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
args
=
[
"--enforce-eager"
,
"--max-model-len"
,
"512"
,
"--dtype"
,
DTYPE
,
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
def
test_single_input_classification
(
server
:
RemoteOpenAIServer
,
model_name
:
str
):
input_text
=
"This product was excellent and exceeded my expectations"
classification_response
=
requests
.
post
(
server
.
url_for
(
"classify"
),
json
=
{
"model"
:
model_name
,
"input"
:
input_text
},
)
classification_response
.
raise_for_status
()
output
=
ClassificationResponse
.
model_validate
(
classification_response
.
json
())
assert
output
.
object
==
"list"
assert
output
.
model
==
MODEL_NAME
assert
len
(
output
.
data
)
==
1
assert
hasattr
(
output
.
data
[
0
],
"label"
)
assert
hasattr
(
output
.
data
[
0
],
"probs"
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
def
test_multiple_inputs_classification
(
server
:
RemoteOpenAIServer
,
model_name
:
str
):
input_texts
=
[
"The product arrived on time and works perfectly"
,
"I'm very satisfied with my purchase, would buy again"
,
"The customer service was helpful and resolved my issue quickly"
,
"This product broke after one week, terrible quality"
,
"I'm very disappointed with this purchase, complete waste of money"
,
"The customer service was rude and unhelpful"
,
]
classification_response
=
requests
.
post
(
server
.
url_for
(
"classify"
),
json
=
{
"model"
:
model_name
,
"input"
:
input_texts
},
)
output
=
ClassificationResponse
.
model_validate
(
classification_response
.
json
())
assert
len
(
output
.
data
)
==
len
(
input_texts
)
for
i
,
item
in
enumerate
(
output
.
data
):
assert
item
.
index
==
i
assert
hasattr
(
item
,
"label"
)
assert
hasattr
(
item
,
"probs"
)
assert
len
(
item
.
probs
)
==
item
.
num_classes
assert
item
.
label
in
[
"Default"
,
"Spoiled"
]
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
def
test_truncate_prompt_tokens
(
server
:
RemoteOpenAIServer
,
model_name
:
str
):
long_text
=
"hello "
*
600
classification_response
=
requests
.
post
(
server
.
url_for
(
"classify"
),
json
=
{
"model"
:
model_name
,
"input"
:
long_text
,
"truncate_prompt_tokens"
:
5
},
)
classification_response
.
raise_for_status
()
output
=
ClassificationResponse
.
model_validate
(
classification_response
.
json
())
assert
len
(
output
.
data
)
==
1
assert
output
.
data
[
0
].
index
==
0
assert
hasattr
(
output
.
data
[
0
],
"probs"
)
assert
output
.
usage
.
prompt_tokens
==
5
assert
output
.
usage
.
total_tokens
==
5
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
def
test_invalid_truncate_prompt_tokens_error
(
server
:
RemoteOpenAIServer
,
model_name
:
str
):
classification_response
=
requests
.
post
(
server
.
url_for
(
"classify"
),
json
=
{
"model"
:
model_name
,
"input"
:
"test"
,
"truncate_prompt_tokens"
:
513
},
)
error
=
classification_response
.
json
()
assert
classification_response
.
status_code
==
400
assert
error
[
"object"
]
==
"error"
assert
"truncate_prompt_tokens"
in
error
[
"message"
]
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
def
test_empty_input_error
(
server
:
RemoteOpenAIServer
,
model_name
:
str
):
classification_response
=
requests
.
post
(
server
.
url_for
(
"classify"
),
json
=
{
"model"
:
model_name
,
"input"
:
""
},
)
error
=
classification_response
.
json
()
assert
classification_response
.
status_code
==
400
assert
error
[
"object"
]
==
"error"
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
def
test_batch_classification_empty_list
(
server
:
RemoteOpenAIServer
,
model_name
:
str
):
classification_response
=
requests
.
post
(
server
.
url_for
(
"classify"
),
json
=
{
"model"
:
model_name
,
"input"
:
[]
},
)
classification_response
.
raise_for_status
()
output
=
ClassificationResponse
.
model_validate
(
classification_response
.
json
())
assert
output
.
object
==
"list"
assert
isinstance
(
output
.
data
,
list
)
assert
len
(
output
.
data
)
==
0
tests/entrypoints/openai/test_cli_args.py
View file @
7a985548
...
@@ -122,31 +122,23 @@ def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser):
...
@@ -122,31 +122,23 @@ def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser):
"""Ensure validation fails if reasoning is enabled with auto tool choice"""
"""Ensure validation fails if reasoning is enabled with auto tool choice"""
args
=
serve_parser
.
parse_args
(
args
=
[
args
=
serve_parser
.
parse_args
(
args
=
[
"--enable-auto-tool-choice"
,
"--enable-auto-tool-choice"
,
"--enable-reasoning"
,
"--reasoning-parser"
,
"deepseek_r1"
,
])
])
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
validate_parsed_serve_args
(
args
)
validate_parsed_serve_args
(
args
)
def
test_
enable_reasoning_
passes_with_reasoning_parser
(
serve_parser
):
def
test_passes_with_reasoning_parser
(
serve_parser
):
"""Ensure validation passes if reasoning is enabled
"""Ensure validation passes if reasoning is enabled
with a reasoning parser"""
with a reasoning parser"""
args
=
serve_parser
.
parse_args
(
args
=
[
args
=
serve_parser
.
parse_args
(
args
=
[
"--enable-reasoning"
,
"--reasoning-parser"
,
"--reasoning-parser"
,
"deepseek_r1"
,
"deepseek_r1"
,
])
])
validate_parsed_serve_args
(
args
)
validate_parsed_serve_args
(
args
)
def
test_enable_reasoning_fails_without_reasoning_parser
(
serve_parser
):
"""Ensure validation fails if reasoning is enabled
without a reasoning parser"""
args
=
serve_parser
.
parse_args
(
args
=
[
"--enable-reasoning"
])
with
pytest
.
raises
(
TypeError
):
validate_parsed_serve_args
(
args
)
def
test_chat_template_validation_for_happy_paths
(
serve_parser
):
def
test_chat_template_validation_for_happy_paths
(
serve_parser
):
"""Ensure validation passes if the chat template exists"""
"""Ensure validation passes if the chat template exists"""
args
=
serve_parser
.
parse_args
(
args
=
serve_parser
.
parse_args
(
...
...
tests/entrypoints/openai/test_completion_with_function_calling.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
import
openai
# use the official client for correctness check
import
pytest
import
pytest_asyncio
# downloading lora to test lora requests
from
...utils
import
RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME
=
"Qwen/Qwen2.5-1.5B-Instruct"
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
# noqa: F811
args
=
[
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"half"
,
"--enable-auto-tool-choice"
,
"--guided-decoding-backend"
,
"xgrammar"
,
"--tool-call-parser"
,
"hermes"
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest_asyncio
.
fixture
async
def
client
(
server
):
async
with
server
.
get_async_client
()
as
async_client
:
yield
async_client
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
async
def
test_required_tool_use
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
tools
=
[
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"get_current_weather"
,
"description"
:
"Get the current weather in a given location"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
,
"description"
:
"The city to find the weather for, e.g. 'Vienna'"
,
"default"
:
"Vienna"
,
},
"country"
:
{
"type"
:
"string"
,
"description"
:
"The country that the city is in, e.g. 'Austria'"
,
},
"unit"
:
{
"type"
:
"string"
,
"description"
:
"The unit to fetch the temperature in"
,
"enum"
:
[
"celsius"
,
"fahrenheit"
],
},
},
"required"
:
[
"country"
,
"unit"
],
},
},
},
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"get_forecast"
,
"description"
:
"Get the weather forecast for a given location"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
,
"description"
:
"The city to get the forecast for, e.g. 'Vienna'"
,
"default"
:
"Vienna"
,
},
"country"
:
{
"type"
:
"string"
,
"description"
:
"The country that the city is in, e.g. 'Austria'"
,
},
"days"
:
{
"type"
:
"integer"
,
"description"
:
"Number of days to get the forecast for (1-7)"
,
},
"unit"
:
{
"type"
:
"string"
,
"description"
:
"The unit to fetch the temperature in"
,
"enum"
:
[
"celsius"
,
"fahrenheit"
],
},
},
"required"
:
[
"country"
,
"days"
,
"unit"
],
},
},
},
]
messages
=
[
{
"role"
:
"user"
,
"content"
:
"Hi! How are you doing today?"
},
{
"role"
:
"assistant"
,
"content"
:
"I'm doing well! How can I help you?"
},
{
"role"
:
"user"
,
"content"
:
"Can you tell me what the current weather is in Berlin and the "
\
"forecast for the next 5 days, in fahrenheit?"
,
},
]
# Non-streaming test
chat_completion
=
await
client
.
chat
.
completions
.
create
(
messages
=
messages
,
model
=
model_name
,
tools
=
tools
,
tool_choice
=
"required"
,
)
assert
chat_completion
.
choices
[
0
].
message
.
tool_calls
is
not
None
assert
len
(
chat_completion
.
choices
[
0
].
message
.
tool_calls
)
>
0
# Streaming test
stream
=
await
client
.
chat
.
completions
.
create
(
messages
=
messages
,
model
=
model_name
,
tools
=
tools
,
tool_choice
=
"required"
,
stream
=
True
,
)
output
=
[]
async
for
chunk
in
stream
:
if
chunk
.
choices
and
chunk
.
choices
[
0
].
delta
.
tool_calls
:
output
.
extend
(
chunk
.
choices
[
0
].
delta
.
tool_calls
)
assert
len
(
output
)
>
0
Prev
1
…
12
13
14
15
16
17
18
19
20
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment