Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2b52805
Commit
d2b52805
authored
Sep 07, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori
parents
9a521c23
5438967f
Changes
501
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1786 additions
and
1170 deletions
+1786
-1170
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+324
-2
tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py
...thvi_io_processor_plugin/prithvi_io_processor/__init__.py
+8
-0
tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py
...rocessor_plugin/prithvi_io_processor/prithvi_processor.py
+449
-0
tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py
...prithvi_io_processor_plugin/prithvi_io_processor/types.py
+59
-0
tests/plugins/prithvi_io_processor_plugin/setup.py
tests/plugins/prithvi_io_processor_plugin/setup.py
+16
-0
tests/plugins_tests/test_io_processor_plugins.py
tests/plugins_tests/test_io_processor_plugins.py
+137
-0
tests/plugins_tests/test_platform_plugins.py
tests/plugins_tests/test_platform_plugins.py
+9
-0
tests/prefix_caching/test_disable_sliding_window.py
tests/prefix_caching/test_disable_sliding_window.py
+0
-49
tests/prefix_caching/test_prefix_caching.py
tests/prefix_caching/test_prefix_caching.py
+0
-231
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+62
-4
tests/quantization/test_configs.py
tests/quantization/test_configs.py
+0
-10
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+2
-4
tests/quantization/test_lm_head.py
tests/quantization/test_lm_head.py
+2
-6
tests/samplers/test_beam_search.py
tests/samplers/test_beam_search.py
+53
-0
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+0
-769
tests/samplers/test_seeded_generate.py
tests/samplers/test_seeded_generate.py
+0
-86
tests/test_sequence.py
tests/test_sequence.py
+38
-2
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+0
-2
tests/tool_use/test_qwen3coder_tool_parser.py
tests/tool_use/test_qwen3coder_tool_parser.py
+173
-5
tests/tool_use/test_seed_oss_tool_parser.py
tests/tool_use/test_seed_oss_tool_parser.py
+454
-0
No files found.
Too many changes to show.
To preserve performance only
501 of 501+
files are displayed.
Plain diff
Email patch
tests/multimodal/test_utils.py
View file @
d2b52805
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
base64
import
math
import
mimetypes
import
os
from
tempfile
import
NamedTemporaryFile
,
TemporaryDirectory
...
...
@@ -20,6 +21,8 @@ from vllm.distributed.parallel_state import (init_distributed_environment,
from
vllm.multimodal.image
import
convert_image_mode
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.utils
import
(
MediaConnector
,
argsort_mm_positions
,
get_load_balance_assignment
,
run_dp_sharded_mrope_vision_model
,
run_dp_sharded_vision_model
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
get_open_port
,
update_environment_variables
...
...
@@ -425,8 +428,8 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
# Set random seed for reproducibility
current_platform
.
seed_everything
(
0
)
device
=
torch
.
device
(
f
"cuda
:
{
local_rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
...
...
@@ -463,3 +466,322 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
# Check that the outputs are close (they should be identical)
assert
torch
.
allclose
(
direct_output
,
sharded_output
,
rtol
=
1e-5
,
atol
=
1e-5
)
@
pytest
.
mark
.
parametrize
(
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
"expected_grouped_sizes_per_gpu,test_description"
,
[
# Empty input
([],
2
,
[],
[
0
,
0
],
[
0
,
0
],
"empty input"
),
# Fewer samples than GPUs
([
100
,
200
],
4
,
[
1
,
0
],
[
1
,
1
,
0
,
0
],
[
200
,
100
,
0
,
0
],
"fewer samples than GPUs"
),
# Single GPU
([
100
,
200
,
300
],
1
,
[
2
,
1
,
0
],
[
3
],
[
600
],
"single GPU"
),
# Balanced assignment
([
100
,
100
,
100
,
100
],
2
,
[
0
,
2
,
1
,
3
],
[
2
,
2
],
[
200
,
200
],
"balanced assignment"
),
# Unbalanced sizes - this one is trickier since the algorithm is greedy
([
1000
,
100
,
200
,
50
],
2
,
[
0
,
2
,
1
,
3
],
[
1
,
3
],
[
1000
,
350
],
"unbalanced sizes"
),
],
)
def
test_get_load_balance_assignment_cases
(
sizes
,
num_gpus
,
expected_shuffle_indices
,
expected_gpu_sample_counts
,
expected_grouped_sizes_per_gpu
,
test_description
):
"""Test get_load_balance_assignment with various input cases."""
result
=
get_load_balance_assignment
(
sizes
,
num_gpus
=
num_gpus
)
(
shuffle_indices
,
gpu_sample_counts
,
grouped_sizes_per_gpu
)
=
result
# Common assertions for all cases
assert
len
(
shuffle_indices
)
==
len
(
sizes
)
assert
len
(
gpu_sample_counts
)
==
num_gpus
assert
len
(
grouped_sizes_per_gpu
)
==
num_gpus
assert
sum
(
gpu_sample_counts
)
==
len
(
sizes
)
assert
shuffle_indices
==
expected_shuffle_indices
assert
gpu_sample_counts
==
expected_gpu_sample_counts
assert
grouped_sizes_per_gpu
==
expected_grouped_sizes_per_gpu
class
SimpleMRopeVisionModel
(
torch
.
nn
.
Module
):
"""A simple vision model for testing mrope functionality."""
def
__init__
(
self
,
spatial_merge_size
:
int
=
2
,
out_hidden_size
:
int
=
64
):
super
().
__init__
()
self
.
spatial_merge_size
=
spatial_merge_size
self
.
out_hidden_size
=
out_hidden_size
self
.
linear
=
torch
.
nn
.
Linear
(
768
,
out_hidden_size
)
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
,
grid_thw_list
:
list
[
list
[
int
]]):
"""Simple forward pass that simulates spatial merging."""
# Apply linear transformation
embeddings
=
self
.
linear
(
pixel_values
)
# Simulate spatial merging by reducing the number of patches
merge_factor
=
self
.
spatial_merge_size
*
self
.
spatial_merge_size
# Group patches and merge spatially
merged_embeddings
=
[]
start_idx
=
0
for
grid_thw
in
grid_thw_list
:
num_patches
=
math
.
prod
(
grid_thw
)
end_idx
=
start_idx
+
num_patches
# Get patches for this image
image_patches
=
embeddings
[
start_idx
:
end_idx
]
# Simulate spatial merging by averaging groups of patches
merged_patches
=
num_patches
//
merge_factor
if
merged_patches
>
0
:
# Reshape and average to simulate merging
reshaped
=
image_patches
[:
merged_patches
*
merge_factor
].
view
(
merged_patches
,
merge_factor
,
-
1
)
merged
=
reshaped
.
mean
(
dim
=
1
)
merged_embeddings
.
append
(
merged
)
start_idx
=
end_idx
if
merged_embeddings
:
return
torch
.
cat
(
merged_embeddings
,
dim
=
0
)
else
:
return
torch
.
empty
((
0
,
self
.
out_hidden_size
),
device
=
pixel_values
.
device
,
dtype
=
pixel_values
.
dtype
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
# Single image
3
,
# Small batch
5
,
# Odd batch size (for testing padding)
],
)
def
test_run_dp_sharded_mrope_vision_model
(
batch_size
:
int
):
world_size
=
2
# Launch processes
mp
.
spawn
(
run_dp_sharded_mrope_vision_model_vs_direct
,
args
=
(
world_size
,
batch_size
,
get_open_port
(),
),
nprocs
=
world_size
,
)
def
run_dp_sharded_mrope_vision_model_vs_direct
(
local_rank
:
int
,
world_size
:
int
,
batch_size
:
int
,
master_port
:
int
):
"""
Test that run_dp_sharded_mrope_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform
.
seed_everything
(
0
)
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
# initialize distributed
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create test data
grid_thw_list
=
[]
pixel_values_list
=
[]
for
i
in
range
(
batch_size
):
# Varying image sizes for better testing
t
,
h
,
w
=
1
,
4
+
i
,
4
+
i
grid_thw_list
.
append
([
t
,
h
,
w
])
num_patches
=
t
*
h
*
w
# Create random pixel values for this image
image_pixels
=
torch
.
randn
(
num_patches
,
768
)
pixel_values_list
.
append
(
image_pixels
)
# Concatenate all pixel values
pixel_values
=
torch
.
cat
(
pixel_values_list
,
dim
=
0
)
# Create a simple mrope vision model
vision_model
=
SimpleMRopeVisionModel
()
# Run the model directly on the full input (only on rank 0)
if
local_rank
==
0
:
with
torch
.
inference_mode
():
direct_output
=
vision_model
(
pixel_values
,
grid_thw_list
)
# Run the model through the sharded function
with
torch
.
inference_mode
():
sharded_output
=
run_dp_sharded_mrope_vision_model
(
vision_model
,
pixel_values
,
grid_thw_list
)
sharded_output
=
torch
.
cat
(
sharded_output
,
dim
=
0
)
# Check that the world size is setup correctly
assert
get_tensor_model_parallel_world_size
()
==
world_size
# Compare outputs (only on rank 0)
if
local_rank
==
0
:
# Check that the outputs have the same shape
assert
direct_output
.
shape
==
sharded_output
.
shape
# Check that the outputs are close (they should be identical)
assert
torch
.
allclose
(
direct_output
,
sharded_output
,
rtol
=
1e-5
,
atol
=
1e-5
)
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_run_dp_sharded_mrope_vision_model_empty_input
():
world_size
=
2
mp
.
spawn
(
run_dp_sharded_mrope_vision_model_empty_input_worker
,
args
=
(
world_size
,
get_open_port
()),
nprocs
=
world_size
,
)
def
run_dp_sharded_mrope_vision_model_empty_input_worker
(
local_rank
:
int
,
world_size
:
int
,
master_port
:
int
):
"""Test run_dp_sharded_mrope_vision_model with empty input."""
# Set up distributed environment
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create empty inputs
pixel_values
=
torch
.
empty
((
0
,
768
))
grid_thw_list
:
list
[
list
[
int
]]
=
[]
vision_model
=
SimpleMRopeVisionModel
()
# Should handle empty input gracefully
with
torch
.
inference_mode
():
output
=
run_dp_sharded_mrope_vision_model
(
vision_model
,
pixel_values
,
grid_thw_list
)
assert
len
(
output
)
==
0
@
multi_gpu_test
(
num_gpus
=
4
)
def
test_run_dp_sharded_mrope_vision_model_uneven_load
():
world_size
=
4
mp
.
spawn
(
run_dp_sharded_mrope_vision_model_uneven_load_worker
,
args
=
(
world_size
,
get_open_port
()),
nprocs
=
world_size
,
)
def
run_dp_sharded_mrope_vision_model_uneven_load_worker
(
local_rank
:
int
,
world_size
:
int
,
master_port
:
int
):
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
# Set up distributed environment
current_platform
.
seed_everything
(
123
)
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create images with very different sizes
grid_thw_list
=
[
[
1
,
2
,
2
],
# Small: 4 patches
[
1
,
8
,
8
],
# Large: 64 patches
[
1
,
3
,
3
],
# Medium: 9 patches
]
pixel_values_list
=
[]
for
grid_thw
in
grid_thw_list
:
num_patches
=
math
.
prod
(
grid_thw
)
image_pixels
=
torch
.
randn
(
num_patches
,
768
)
pixel_values_list
.
append
(
image_pixels
)
pixel_values
=
torch
.
cat
(
pixel_values_list
,
dim
=
0
)
vision_model
=
SimpleMRopeVisionModel
()
# Should handle uneven distribution without errors
with
torch
.
inference_mode
():
output_tuple
=
run_dp_sharded_mrope_vision_model
(
vision_model
,
pixel_values
,
grid_thw_list
)
# Verify output shape is reasonable
merge_factor
=
vision_model
.
spatial_merge_size
**
2
expected_output_patches
=
list
(
math
.
prod
(
grid_thw
)
//
merge_factor
for
grid_thw
in
grid_thw_list
)
for
i
,
output
in
enumerate
(
output_tuple
):
assert
output
.
shape
[
0
]
==
expected_output_patches
[
i
]
assert
output
.
shape
[
1
]
==
vision_model
.
out_hidden_size
@
pytest
.
mark
.
parametrize
(
"spatial_merge_size"
,
[
2
,
4
])
def
test_simple_mrope_vision_model_spatial_merge
(
spatial_merge_size
:
int
):
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
device
=
current_platform
.
device_type
grid_thw_list
=
[[
1
,
4
,
4
],
[
1
,
6
,
6
]]
# Two images
pixel_values_list
=
[]
for
grid_thw
in
grid_thw_list
:
num_patches
=
math
.
prod
(
grid_thw
)
image_pixels
=
torch
.
randn
(
num_patches
,
768
,
device
=
device
)
pixel_values_list
.
append
(
image_pixels
)
pixel_values
=
torch
.
cat
(
pixel_values_list
,
dim
=
0
)
vision_model
=
SimpleMRopeVisionModel
(
spatial_merge_size
=
spatial_merge_size
).
to
(
device
)
with
torch
.
inference_mode
():
output
=
vision_model
(
pixel_values
,
grid_thw_list
)
# Verify output dimensions based on spatial merging
total_patches
=
sum
(
math
.
prod
(
grid_thw
)
for
grid_thw
in
grid_thw_list
)
merge_factor
=
spatial_merge_size
**
2
expected_output_patches
=
total_patches
//
merge_factor
assert
output
.
shape
[
0
]
==
expected_output_patches
assert
output
.
shape
[
1
]
==
vision_model
.
out_hidden_size
tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def
register_prithvi_india
():
return
"prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia"
# noqa: E501
def
register_prithvi_valencia
():
return
"prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia"
# noqa: E501
tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
base64
import
datetime
import
os
import
tempfile
import
urllib.request
from
collections.abc
import
AsyncGenerator
,
Sequence
from
typing
import
Any
,
Optional
,
Union
import
albumentations
import
numpy
as
np
import
rasterio
import
regex
as
re
import
torch
from
einops
import
rearrange
from
terratorch.datamodules
import
Sen1Floods11NonGeoDataModule
from
vllm.config
import
VllmConfig
from
vllm.entrypoints.openai.protocol
import
(
IOProcessorRequest
,
IOProcessorResponse
)
from
vllm.inputs.data
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.plugins.io_processors.interface
import
(
IOProcessor
,
IOProcessorInput
,
IOProcessorOutput
)
from
.types
import
DataModuleConfig
,
ImagePrompt
,
ImageRequestOutput
logger
=
init_logger
(
__name__
)
NO_DATA
=
-
9999
NO_DATA_FLOAT
=
0.0001
OFFSET
=
0
PERCENTILE
=
99
DEFAULT_INPUT_INDICES
=
[
0
,
1
,
2
,
3
,
4
,
5
]
datamodule_config
:
DataModuleConfig
=
{
"bands"
:
[
"BLUE"
,
"GREEN"
,
"RED"
,
"NIR_NARROW"
,
"SWIR_1"
,
"SWIR_2"
],
"batch_size"
:
16
,
"constant_scale"
:
0.0001
,
"data_root"
:
"/dccstor/geofm-finetuning/datasets/sen1floods11"
,
"drop_last"
:
True
,
"no_data_replace"
:
0.0
,
"no_label_replace"
:
-
1
,
"num_workers"
:
8
,
"test_transform"
:
[
albumentations
.
Resize
(
always_apply
=
False
,
height
=
448
,
interpolation
=
1
,
p
=
1
,
width
=
448
),
albumentations
.
pytorch
.
ToTensorV2
(
transpose_mask
=
False
,
always_apply
=
True
,
p
=
1.0
),
],
}
def
save_geotiff
(
image
:
torch
.
Tensor
,
meta
:
dict
,
out_format
:
str
)
->
str
|
bytes
:
"""Save multi-band image in Geotiff file.
Args:
image: np.ndarray with shape (bands, height, width)
output_path: path where to save the image
meta: dict with meta info.
"""
if
out_format
==
"path"
:
# create temp file
file_path
=
os
.
path
.
join
(
os
.
getcwd
(),
"prediction.tiff"
)
with
rasterio
.
open
(
file_path
,
"w"
,
**
meta
)
as
dest
:
for
i
in
range
(
image
.
shape
[
0
]):
dest
.
write
(
image
[
i
,
:,
:],
i
+
1
)
return
file_path
elif
out_format
==
"b64_json"
:
with
tempfile
.
NamedTemporaryFile
()
as
tmpfile
:
with
rasterio
.
open
(
tmpfile
.
name
,
"w"
,
**
meta
)
as
dest
:
for
i
in
range
(
image
.
shape
[
0
]):
dest
.
write
(
image
[
i
,
:,
:],
i
+
1
)
file_data
=
tmpfile
.
read
()
return
base64
.
b64encode
(
file_data
)
else
:
raise
ValueError
(
"Unknown output format"
)
def
_convert_np_uint8
(
float_image
:
torch
.
Tensor
):
image
=
float_image
.
numpy
()
*
255.0
image
=
image
.
astype
(
dtype
=
np
.
uint8
)
return
image
def
read_geotiff
(
file_path
:
Optional
[
str
]
=
None
,
path_type
:
Optional
[
str
]
=
None
,
file_data
:
Optional
[
bytes
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
dict
,
tuple
[
float
,
float
]
|
None
]:
"""Read all bands from *file_path* and return image + meta info.
Args:
file_path: path to image file.
Returns:
np.ndarray with shape (bands, height, width)
meta info dict
"""
if
all
([
x
is
None
for
x
in
[
file_path
,
path_type
,
file_data
]]):
raise
Exception
(
"All input fields to read_geotiff are None"
)
write_to_file
:
Optional
[
bytes
]
=
None
path
:
Optional
[
str
]
=
None
if
file_data
is
not
None
:
# with tempfile.NamedTemporaryFile() as tmpfile:
# tmpfile.write(file_data)
# path = tmpfile.name
write_to_file
=
file_data
elif
file_path
is
not
None
and
path_type
==
"url"
:
resp
=
urllib
.
request
.
urlopen
(
file_path
)
# with tempfile.NamedTemporaryFile() as tmpfile:
# tmpfile.write(resp.read())
# path = tmpfile.name
write_to_file
=
resp
.
read
()
elif
file_path
is
not
None
and
path_type
==
"path"
:
path
=
file_path
elif
file_path
is
not
None
and
path_type
==
"b64_json"
:
image_data
=
base64
.
b64decode
(
file_path
)
# with tempfile.NamedTemporaryFile() as tmpfile:
# tmpfile.write(image_data)
# path = tmpfile.name
write_to_file
=
image_data
else
:
raise
Exception
(
"Wrong combination of parameters to read_geotiff"
)
with
tempfile
.
NamedTemporaryFile
()
as
tmpfile
:
path_to_use
=
None
if
write_to_file
:
tmpfile
.
write
(
write_to_file
)
path_to_use
=
tmpfile
.
name
elif
path
:
path_to_use
=
path
with
rasterio
.
open
(
path_to_use
)
as
src
:
img
=
src
.
read
()
meta
=
src
.
meta
try
:
coords
=
src
.
lnglat
()
except
Exception
:
# Cannot read coords
coords
=
None
return
img
,
meta
,
coords
def
load_image
(
data
:
Union
[
list
[
str
]],
path_type
:
str
,
mean
:
Optional
[
list
[
float
]]
=
None
,
std
:
Optional
[
list
[
float
]]
=
None
,
indices
:
Optional
[
Union
[
list
[
int
],
None
]]
=
None
,
):
"""Build an input example by loading images in *file_paths*.
Args:
file_paths: list of file paths .
mean: list containing mean values for each band in the
images in *file_paths*.
std: list containing std values for each band in the
images in *file_paths*.
Returns:
np.array containing created example
list of meta info for each image in *file_paths*
"""
imgs
=
[]
metas
=
[]
temporal_coords
=
[]
location_coords
=
[]
for
file
in
data
:
# if isinstance(file, bytes):
# img, meta, coords = read_geotiff(file_data=file)
# else:
img
,
meta
,
coords
=
read_geotiff
(
file_path
=
file
,
path_type
=
path_type
)
# Rescaling (don't normalize on nodata)
img
=
np
.
moveaxis
(
img
,
0
,
-
1
)
# channels last for rescaling
if
indices
is
not
None
:
img
=
img
[...,
indices
]
if
mean
is
not
None
and
std
is
not
None
:
img
=
np
.
where
(
img
==
NO_DATA
,
NO_DATA_FLOAT
,
(
img
-
mean
)
/
std
)
imgs
.
append
(
img
)
metas
.
append
(
meta
)
if
coords
is
not
None
:
location_coords
.
append
(
coords
)
try
:
match
=
re
.
search
(
r
"(\d{7,8}T\d{6})"
,
file
)
if
match
:
year
=
int
(
match
.
group
(
1
)[:
4
])
julian_day
=
match
.
group
(
1
).
split
(
"T"
)[
0
][
4
:]
if
len
(
julian_day
)
==
3
:
julian_day
=
int
(
julian_day
)
else
:
julian_day
=
(
datetime
.
datetime
.
strptime
(
julian_day
,
"%m%d"
).
timetuple
().
tm_yday
)
temporal_coords
.
append
([
year
,
julian_day
])
except
Exception
:
logger
.
exception
(
"Could not extract timestamp for %s"
,
file
)
imgs
=
np
.
stack
(
imgs
,
axis
=
0
)
# num_frames, H, W, C
imgs
=
np
.
moveaxis
(
imgs
,
-
1
,
0
).
astype
(
"float32"
)
# C, num_frames, H, W
imgs
=
np
.
expand_dims
(
imgs
,
axis
=
0
)
# add batch di
return
imgs
,
temporal_coords
,
location_coords
,
metas
class
PrithviMultimodalDataProcessor
(
IOProcessor
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
super
().
__init__
(
vllm_config
)
self
.
datamodule
=
Sen1Floods11NonGeoDataModule
(
data_root
=
datamodule_config
[
"data_root"
],
batch_size
=
datamodule_config
[
"batch_size"
],
num_workers
=
datamodule_config
[
"num_workers"
],
bands
=
datamodule_config
[
"bands"
],
drop_last
=
datamodule_config
[
"drop_last"
],
test_transform
=
datamodule_config
[
"test_transform"
],
)
self
.
img_size
=
512
self
.
h1
=
1
self
.
w1
=
1
self
.
original_h
=
512
self
.
original_w
=
512
self
.
batch_size
=
1
self
.
meta_data
=
None
self
.
requests_cache
:
dict
[
str
,
dict
[
str
,
Any
]]
=
{}
self
.
indices
=
DEFAULT_INPUT_INDICES
def
parse_request
(
self
,
request
:
Any
)
->
IOProcessorInput
:
if
type
(
request
)
is
dict
:
image_prompt
=
ImagePrompt
(
**
request
)
return
image_prompt
if
isinstance
(
request
,
IOProcessorRequest
):
if
not
hasattr
(
request
,
"data"
):
raise
ValueError
(
"missing 'data' field in OpenAIBaseModel Request"
)
request_data
=
request
.
data
if
type
(
request_data
)
is
dict
:
return
ImagePrompt
(
**
request_data
)
else
:
raise
ValueError
(
"Unable to parse the request data"
)
raise
ValueError
(
"Unable to parse request"
)
def
output_to_response
(
self
,
plugin_output
:
IOProcessorOutput
)
->
IOProcessorResponse
:
return
IOProcessorResponse
(
request_id
=
plugin_output
.
request_id
,
data
=
plugin_output
,
)
def
pre_process
(
self
,
prompt
:
IOProcessorInput
,
request_id
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
Union
[
PromptType
,
Sequence
[
PromptType
]]:
image_data
=
dict
(
prompt
)
if
request_id
:
self
.
requests_cache
[
request_id
]
=
{
"out_format"
:
image_data
[
"out_data_format"
],
}
input_data
,
temporal_coords
,
location_coords
,
meta_data
=
load_image
(
data
=
[
image_data
[
"data"
]],
indices
=
self
.
indices
,
path_type
=
image_data
[
"data_format"
],
)
self
.
meta_data
=
meta_data
[
0
]
if
input_data
.
mean
()
>
1
:
input_data
=
input_data
/
10000
# Convert to range 0-1
self
.
original_h
,
self
.
original_w
=
input_data
.
shape
[
-
2
:]
pad_h
=
(
self
.
img_size
-
(
self
.
original_h
%
self
.
img_size
))
%
self
.
img_size
pad_w
=
(
self
.
img_size
-
(
self
.
original_w
%
self
.
img_size
))
%
self
.
img_size
input_data
=
np
.
pad
(
input_data
,
((
0
,
0
),
(
0
,
0
),
(
0
,
0
),
(
0
,
pad_h
),
(
0
,
pad_w
)),
mode
=
"reflect"
,
)
batch
=
torch
.
tensor
(
input_data
)
windows
=
batch
.
unfold
(
3
,
self
.
img_size
,
self
.
img_size
).
unfold
(
4
,
self
.
img_size
,
self
.
img_size
)
self
.
h1
,
self
.
w1
=
windows
.
shape
[
3
:
5
]
windows
=
rearrange
(
windows
,
"b c t h1 w1 h w -> (b h1 w1) c t h w"
,
h
=
self
.
img_size
,
w
=
self
.
img_size
,
)
# Split into batches if number of windows > batch_size
num_batches
=
(
windows
.
shape
[
0
]
//
self
.
batch_size
if
windows
.
shape
[
0
]
>
self
.
batch_size
else
1
)
windows
=
torch
.
tensor_split
(
windows
,
num_batches
,
dim
=
0
)
if
temporal_coords
:
temporal_coords
=
torch
.
tensor
(
temporal_coords
).
unsqueeze
(
0
)
else
:
temporal_coords
=
None
if
location_coords
:
location_coords
=
torch
.
tensor
(
location_coords
[
0
]).
unsqueeze
(
0
)
else
:
location_coords
=
None
prompts
=
[]
for
window
in
windows
:
# Apply standardization
window
=
self
.
datamodule
.
test_transform
(
image
=
window
.
squeeze
().
numpy
().
transpose
(
1
,
2
,
0
))
window
=
self
.
datamodule
.
aug
(
window
)[
"image"
]
prompts
.
append
({
"prompt_token_ids"
:
[
1
],
"multi_modal_data"
:
{
"pixel_values"
:
window
.
to
(
torch
.
float16
)[
0
],
"location_coords"
:
location_coords
.
to
(
torch
.
float16
),
},
})
return
prompts
async
def
pre_process_async
(
self
,
prompt
:
IOProcessorInput
,
request_id
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
Union
[
PromptType
,
Sequence
[
PromptType
]]:
return
self
.
pre_process
(
prompt
,
request_id
,
**
kwargs
)
def
post_process
(
self
,
model_output
:
Sequence
[
PoolingRequestOutput
],
request_id
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
IOProcessorOutput
:
pred_imgs_list
=
[]
if
request_id
and
(
request_id
in
self
.
requests_cache
):
out_format
=
self
.
requests_cache
[
request_id
][
"out_format"
]
else
:
out_format
=
"b64_json"
for
output
in
model_output
:
y_hat
=
output
.
outputs
.
data
.
argmax
(
dim
=
1
)
pred
=
torch
.
nn
.
functional
.
interpolate
(
y_hat
.
unsqueeze
(
1
).
float
(),
size
=
self
.
img_size
,
mode
=
"nearest"
,
)
pred_imgs_list
.
append
(
pred
)
pred_imgs
:
torch
.
Tensor
=
torch
.
concat
(
pred_imgs_list
,
dim
=
0
)
# Build images from patches
pred_imgs
=
rearrange
(
pred_imgs
,
"(b h1 w1) c h w -> b c (h1 h) (w1 w)"
,
h
=
self
.
img_size
,
w
=
self
.
img_size
,
b
=
1
,
c
=
1
,
h1
=
self
.
h1
,
w1
=
self
.
w1
,
)
# Cut padded area back to original size
pred_imgs
=
pred_imgs
[...,
:
self
.
original_h
,
:
self
.
original_w
]
# Squeeze (batch size 1)
pred_imgs
=
pred_imgs
[
0
]
if
not
self
.
meta_data
:
raise
ValueError
(
"No metadata available for the current task"
)
self
.
meta_data
.
update
(
count
=
1
,
dtype
=
"uint8"
,
compress
=
"lzw"
,
nodata
=
0
)
out_data
=
save_geotiff
(
_convert_np_uint8
(
pred_imgs
),
self
.
meta_data
,
out_format
)
return
ImageRequestOutput
(
type
=
out_format
,
format
=
"tiff"
,
data
=
out_data
,
request_id
=
request_id
)
async
def
post_process_async
(
self
,
model_output
:
AsyncGenerator
[
tuple
[
int
,
PoolingRequestOutput
]],
request_id
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
IOProcessorOutput
:
collected_output
=
[
item
async
for
i
,
item
in
model_output
]
return
self
.
post_process
(
collected_output
,
request_id
,
**
kwargs
)
class
PrithviMultimodalDataProcessorIndia
(
PrithviMultimodalDataProcessor
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
super
().
__init__
(
vllm_config
)
self
.
indices
=
[
1
,
2
,
3
,
8
,
11
,
12
]
class
PrithviMultimodalDataProcessorValencia
(
PrithviMultimodalDataProcessor
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
super
().
__init__
(
vllm_config
)
self
.
indices
=
[
0
,
1
,
2
,
3
,
4
,
5
]
tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Literal
,
Optional
,
TypedDict
,
Union
import
albumentations
from
pydantic
import
BaseModel
class
DataModuleConfig
(
TypedDict
):
bands
:
list
[
str
]
batch_size
:
int
constant_scale
:
float
data_root
:
str
drop_last
:
bool
no_data_replace
:
float
no_label_replace
:
int
num_workers
:
int
test_transform
:
list
[
albumentations
.
core
.
transforms_interface
.
BasicTransform
]
class
ImagePrompt
(
BaseModel
):
data_format
:
Literal
[
"b64_json"
,
"bytes"
,
"url"
]
"""
This is the data type for the input image
"""
image_format
:
str
"""
This is the image format (e.g., jpeg, png, etc.)
"""
out_data_format
:
Literal
[
"b64_json"
,
"url"
]
data
:
Any
"""
Input image data
"""
MultiModalPromptType
=
Union
[
ImagePrompt
]
class
ImageRequestOutput
(
BaseModel
):
"""
The output data of an image request to vLLM.
Args:
type (str): The data content type [path, object]
format (str): The image format (e.g., jpeg, png, etc.)
data (Any): The resulting data.
"""
type
:
Literal
[
"path"
,
"b64_json"
]
format
:
str
data
:
str
request_id
:
Optional
[
str
]
=
None
tests/plugins/prithvi_io_processor_plugin/setup.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
setuptools
import
setup
setup
(
name
=
"prithvi_io_processor_plugin"
,
version
=
"0.1"
,
packages
=
[
"prithvi_io_processor"
],
entry_points
=
{
"vllm.io_processor_plugins"
:
[
"prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india"
,
# noqa: E501
"prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia"
,
# noqa: E501
]
},
)
tests/plugins_tests/test_io_processor_plugins.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
base64
import
pytest
import
requests
from
tests.utils
import
RemoteOpenAIServer
from
vllm.config
import
VllmConfig
from
vllm.entrypoints.llm
import
LLM
from
vllm.entrypoints.openai.protocol
import
IOProcessorResponse
from
vllm.plugins.io_processors
import
get_io_processor
from
vllm.pooling_params
import
PoolingParams
MODEL_NAME
=
"christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
image_url
=
"https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"
# noqa: E501
def
test_loading_missing_plugin
():
vllm_config
=
VllmConfig
()
with
pytest
.
raises
(
ValueError
):
get_io_processor
(
vllm_config
,
"wrong_plugin"
)
def
test_loading_engine_with_wrong_plugin
():
with
pytest
.
raises
(
ValueError
):
LLM
(
model
=
MODEL_NAME
,
skip_tokenizer_init
=
True
,
trust_remote_code
=
True
,
enforce_eager
=
True
,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs
=
32
,
io_processor_plugin
=
"wrong_plugin"
,
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
def
test_prithvi_mae_plugin_offline
(
vllm_runner
,
model_name
:
str
):
img_prompt
=
dict
(
data
=
image_url
,
data_format
=
"url"
,
image_format
=
"tiff"
,
out_data_format
=
"b64_json"
,
)
pooling_params
=
PoolingParams
(
task
=
"encode"
,
softmax
=
False
)
with
vllm_runner
(
model_name
,
runner
=
"pooling"
,
skip_tokenizer_init
=
True
,
trust_remote_code
=
True
,
enforce_eager
=
True
,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs
=
1
,
io_processor_plugin
=
"prithvi_to_tiff_valencia"
,
)
as
llm_runner
:
pooler_output
=
llm_runner
.
get_llm
().
encode
(
img_prompt
,
pooling_params
=
pooling_params
,
)
output
=
pooler_output
[
0
].
outputs
# verify the output is formatted as expected for this plugin
assert
all
(
hasattr
(
output
,
attr
)
for
attr
in
[
"type"
,
"format"
,
"data"
,
"request_id"
])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64
.
b64decode
(
output
.
data
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
args
=
[
"--runner"
,
"pooling"
,
"--enforce-eager"
,
"--trust-remote-code"
,
"--skip-tokenizer-init"
,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
"--max-num-seqs"
,
"32"
,
"--io-processor-plugin"
,
"prithvi_to_tiff_valencia"
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
async
def
test_prithvi_mae_plugin_online
(
server
:
RemoteOpenAIServer
,
model_name
:
str
,
):
request_payload_url
=
{
"data"
:
{
"data"
:
image_url
,
"data_format"
:
"url"
,
"image_format"
:
"tiff"
,
"out_data_format"
:
"b64_json"
,
},
"priority"
:
0
,
"model"
:
model_name
,
}
ret
=
requests
.
post
(
server
.
url_for
(
"pooling"
),
json
=
request_payload_url
,
)
response
=
ret
.
json
()
# verify the request response is in the correct format
assert
(
parsed_response
:
=
IOProcessorResponse
(
**
response
))
# verify the output is formatted as expected for this plugin
plugin_data
=
parsed_response
.
data
assert
all
(
plugin_data
.
get
(
attr
)
for
attr
in
[
"type"
,
"format"
,
"data"
,
"request_id"
])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64
.
b64decode
(
plugin_data
[
"data"
])
tests/plugins_tests/test_platform_plugins.py
View file @
d2b52805
...
...
@@ -7,6 +7,15 @@ import torch
from
vllm.plugins
import
load_general_plugins
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
def
test_platform_plugins
():
# simulate workload by running an example
import
runpy
...
...
tests/prefix_caching/test_disable_sliding_window.py
deleted
100644 → 0
View file @
9a521c23
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compare the with and without prefix caching.
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import
pytest
from
vllm
import
LLM
from
vllm.distributed
import
cleanup_dist_env_and_memory
MODEL_LEN_LEN
=
[
# Example models with sliding window.
(
"bigcode/starcoder2-3b"
,
4096
,
16384
),
# ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI
# Confirm model with sliding window works.
# config has "use_sliding_window": false
(
"Qwen/Qwen1.5-0.5B-Chat"
,
32768
,
32768
),
# config has no sliding window attribute.
(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
,
2048
,
2048
),
]
@
pytest
.
mark
.
parametrize
(
"model_len_len"
,
MODEL_LEN_LEN
)
def
test_disable_sliding_window
(
model_len_len
,
):
model
,
sliding_len
,
full_len
=
model_len_len
disabled_llm
=
LLM
(
model
,
disable_sliding_window
=
True
)
disabled_llm
.
generate
(
"Hi my name is"
)
model_config
=
disabled_llm
.
llm_engine
.
model_config
assert
model_config
.
max_model_len
==
sliding_len
,
(
"Max len expected to equal sliding_len of %s, but got %s"
,
sliding_len
,
model_config
.
max_model_len
)
del
disabled_llm
cleanup_dist_env_and_memory
()
enabled_llm
=
LLM
(
model
,
enforce_eager
=
True
,
disable_sliding_window
=
False
,
enable_prefix_caching
=
False
)
enabled_llm
.
generate
(
"Hi my name is"
)
model_config
=
enabled_llm
.
llm_engine
.
model_config
assert
model_config
.
max_model_len
==
full_len
,
(
"Max len expected to equal full_len of %s, but got %s"
,
full_len
,
model_config
.
max_model_len
)
del
enabled_llm
cleanup_dist_env_and_memory
()
tests/prefix_caching/test_prefix_caching.py
deleted
100644 → 0
View file @
9a521c23
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compare the with and without prefix caching.
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
from
__future__
import
annotations
import
pytest
from
tests.conftest
import
VllmRunner
from
tests.core.utils
import
SchedulerProxy
,
create_dummy_prompt
from
vllm
import
SamplingParams
,
TokensPrompt
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_BACKEND_ENV_VAR
from
..models.utils
import
check_outputs_equal
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
yield
MODELS
=
[
"distilbert/distilgpt2"
,
]
UNSTABLE_PROMPT_SEQUENCE
=
[
([
0
]
*
588
)
+
([
1
]
*
1332
)
+
([
2
]
*
30
)
+
([
3
]
*
1
),
([
0
]
*
588
)
+
([
1
]
*
1332
)
+
([
4
]
*
3
)
+
([
5
]
*
50
),
([
0
]
*
588
)
+
([
1
]
*
1332
)
+
([
2
]
*
30
)
+
([
6
]
*
95
),
([
0
]
*
588
)
+
([
1
]
*
1332
)
+
([
4
]
*
3
)
+
([
7
]
*
174
),
([
0
]
*
588
)
+
([
8
]
*
1539
),
]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"cached_position"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"enable_chunked_prefill"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
def
test_mixed_requests
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
backend
:
str
,
dtype
:
str
,
max_tokens
:
int
,
cached_position
:
int
,
enable_chunked_prefill
:
bool
,
block_size
:
int
,
monkeypatch
:
pytest
.
MonkeyPatch
,
)
->
None
:
"""
Test the case when some sequences have the prefix cache hit
and the others don't. The cached position determines where
the sequence is at among the batch of prefills.
"""
if
backend
==
"FLASHINFER"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Flashinfer does not support ROCm/HIP."
)
if
backend
==
"XFORMERS"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Xformers does not support ROCm/HIP."
)
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
backend
)
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
cached_prompt
=
example_prompts
[
cached_position
]
with
vllm_runner
(
model
,
dtype
=
dtype
,
enable_prefix_caching
=
True
,
enable_chunked_prefill
=
enable_chunked_prefill
,
block_size
=
block_size
,
)
as
vllm_model
:
# Run the first prompt so the cache is populated
vllm_outputs
=
vllm_model
.
generate_greedy
([
cached_prompt
],
max_tokens
)
# Run all the promopts
greedy_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
)
req_outputs
=
vllm_model
.
llm
.
generate
(
example_prompts
,
greedy_params
)
# Verify number of cached tokens
for
i
in
range
(
len
(
req_outputs
)):
if
i
==
cached_position
:
expected_num_cached_tokens
=
(
len
(
req_outputs
[
i
].
prompt_token_ids
)
//
block_size
)
*
block_size
else
:
expected_num_cached_tokens
=
0
assert
(
req_outputs
[
i
].
num_cached_tokens
==
expected_num_cached_tokens
)
vllm_outputs
=
[(
output
.
prompt_token_ids
+
list
(
output
.
outputs
[
0
].
token_ids
),
output
.
prompt
+
output
.
outputs
[
0
].
text
,
)
for
output
in
req_outputs
]
check_outputs_equal
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
def
test_unstable_prompt_sequence
(
vllm_runner
,
backend
:
str
,
monkeypatch
:
pytest
.
MonkeyPatch
,
)
->
None
:
if
backend
==
"FLASHINFER"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Flashinfer does not support ROCm/HIP."
)
if
backend
==
"XFORMERS"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Xformers does not support ROCm/HIP."
)
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
backend
)
with
vllm_runner
(
"Qwen/Qwen2.5-0.5B-Instruct"
,
enable_chunked_prefill
=
True
,
enable_prefix_caching
=
True
,
max_model_len
=
4096
,
)
as
vllm_model
:
for
prompt
in
UNSTABLE_PROMPT_SEQUENCE
:
vllm_model
.
generate
(
TokensPrompt
(
prompt_token_ids
=
prompt
),
SamplingParams
(
max_tokens
=
1
))
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
def
test_fully_cached_prefill_needs_uncached_token
(
model
):
block_size
=
16
max_num_batched_tokens
=
16
num_output_tokens
=
5
# Make a vllm engine
runner
=
VllmRunner
(
model_name
=
model
,
gpu_memory_utilization
=
0.7
,
enable_chunked_prefill
=
True
,
enforce_eager
=
True
,
enable_prefix_caching
=
True
,
block_size
=
block_size
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_seqs
=
max_num_batched_tokens
,
)
engine
:
LLMEngine
=
runner
.
llm
.
llm_engine
scheduler
:
Scheduler
=
SchedulerProxy
(
engine
.
scheduler
[
0
])
# type: ignore
engine
.
scheduler
[
0
]
=
scheduler
# SeqA
seqA_tokens
=
list
(
range
(
2
*
block_size
))
seqA
,
seq_groupA
=
create_dummy_prompt
(
request_id
=
"0"
,
prompt_tokens
=
seqA_tokens
,
max_tokens
=
num_output_tokens
,
block_size
=
block_size
,
)
scheduler
.
add_seq_group
(
seq_groupA
)
assert
seqA
.
data
.
get_num_computed_tokens
()
==
0
# Prefill seqA
while
not
seqA
.
is_finished
():
engine
.
step
()
# seqB
seqB_tokens
=
[
t
+
1
for
t
in
seqA_tokens
]
# shift by 1
seqB
,
seq_groupB
=
create_dummy_prompt
(
request_id
=
"1"
,
prompt_tokens
=
seqB_tokens
,
max_tokens
=
num_output_tokens
,
block_size
=
block_size
,
)
# seqC is the same as seqA
seqC
,
seq_groupC
=
create_dummy_prompt
(
request_id
=
"2"
,
prompt_tokens
=
seqA_tokens
,
max_tokens
=
num_output_tokens
,
block_size
=
block_size
,
)
scheduler
.
add_seq_group
(
seq_groupB
)
scheduler
.
add_seq_group
(
seq_groupC
)
# Even seqC is fully cached, it should not be prefilled since we
# require at least 1 uncached token.
engine
.
step
()
sched_metas
,
sched_out
,
_
=
scheduler
.
last_schedule_ret
()
assert
len
(
sched_out
.
scheduled_seq_groups
)
==
1
assert
(
sched_out
.
scheduled_seq_groups
[
0
].
seq_group
.
request_id
==
seq_groupB
.
request_id
)
assert
(
sched_out
.
scheduled_seq_groups
[
0
].
token_chunk_size
==
max_num_batched_tokens
)
# When seqB is finished, seqC could be prefilled.
while
not
seqB
.
is_finished
():
engine
.
step
()
sched_metas
,
sched_out
,
_
=
scheduler
.
last_schedule_ret
()
assert
len
(
sched_out
.
scheduled_seq_groups
)
==
1
assert
(
sched_out
.
scheduled_seq_groups
[
0
].
seq_group
.
request_id
==
seq_groupB
.
request_id
)
engine
.
step
()
sched_metas
,
sched_out
,
_
=
scheduler
.
last_schedule_ret
()
assert
len
(
sched_out
.
scheduled_seq_groups
)
==
1
assert
(
sched_out
.
scheduled_seq_groups
[
0
].
seq_group
.
request_id
==
seq_groupC
.
request_id
)
assert
sched_out
.
scheduled_seq_groups
[
0
].
token_chunk_size
==
len
(
seqA_tokens
)
tests/quantization/test_compressed_tensors.py
View file @
d2b52805
...
...
@@ -14,10 +14,10 @@ from compressed_tensors.quantization import QuantizationType
from
tests.models.utils
import
check_logprobs_close
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensors24
,
CompressedTensorsLinearMethod
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A
16
Fp
4
,
CompressedTensorsW4A16
Sparse2
4
,
CompressedTensorsW
8A8Fp8
,
CompressedTensorsW8A8
Int
8
,
CompressedTensorsW8A
16Fp
8
,
CompressedTensorsWNA16
)
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A
8
Fp
8
,
CompressedTensorsW4A16
Fp
4
,
CompressedTensorsW
4A16Sparse24
,
CompressedTensorsW8A8
Fp
8
,
CompressedTensorsW8A
8Int
8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
cutlass_fp4_supported
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
...
@@ -683,3 +683,61 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
()
or
not
current_platform
.
has_device_capability
(
90
),
reason
=
"W4A8 FP8 is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args"
,
[
(
"czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e"
,
CompressedTensorsW4A8Fp8
)
])
def
test_compressed_tensors_w4a8_fp8
(
vllm_runner
,
args
):
model
,
scheme
=
args
with
vllm_runner
(
model
,
enforce_eager
=
True
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
o_proj
=
layer
.
self_attn
.
o_proj
gate_up_proj
=
layer
.
mlp
.
gate_up_proj
down_proj
=
layer
.
mlp
.
down_proj
for
proj
in
(
qkv_proj
,
o_proj
,
gate_up_proj
,
down_proj
):
assert
isinstance
(
proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
proj
.
scheme
,
scheme
)
assert
proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
proj
.
weight_scale
.
dtype
is
torch
.
float8_e4m3fn
assert
proj
.
weight_chan_scale
.
dtype
is
torch
.
float32
assert
proj
.
scheme
.
group_size
==
128
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test is skipped on non-CUDA platform."
)
@
pytest
.
mark
.
parametrize
(
"model,prompt,exp_perplexity"
,
[
(
"nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16"
,
"Flat is better than nested.
\n
Sparse is better than dense."
,
150.0
,
),
(
"nm-testing/Llama-3.2-1B-Instruct-quip-w4a16"
,
"Flat is better than nested.
\n
Sparse is better than dense."
,
150.0
,
),
])
def
test_compressed_tensors_transforms_perplexity
(
vllm_runner
,
model
,
prompt
,
exp_perplexity
):
with
vllm_runner
(
model
,
enforce_eager
=
True
)
as
llm
:
perplexity
=
llm
.
generate_prompt_perplexity
([
prompt
])[
0
]
print
(
perplexity
)
assert
perplexity
<=
exp_perplexity
\ No newline at end of file
tests/quantization/test_configs.py
View file @
d2b52805
...
...
@@ -22,22 +22,12 @@ class ModelPair:
MODEL_ARG_EXPTYPES
=
[
# AUTOGPTQ
# compat: autogptq <=0.7.1 is_marlin_format: bool
# Model Serialized in Marlin Format should always use Marlin kernel.
(
"neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"
,
None
,
"marlin"
),
(
"neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"
,
"marlin"
,
"marlin"
),
(
"neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"
,
"gptq"
,
"marlin"
),
(
"neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"
,
"awq"
,
"ERROR"
),
# Model Serialized in Exllama Format.
(
"TheBloke/Llama-2-7B-Chat-GPTQ"
,
None
,
"gptq_marlin"
),
(
"TheBloke/Llama-2-7B-Chat-GPTQ"
,
"marlin"
,
"gptq_marlin"
),
(
"TheBloke/Llama-2-7B-Chat-GPTQ"
,
"gptq"
,
"gptq"
),
(
"TheBloke/Llama-2-7B-Chat-GPTQ"
,
"awq"
,
"ERROR"
),
# compat: autogptq >=0.8.0 use checkpoint_format: str
# Model Serialized in Marlin Format should always use Marlin kernel.
(
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"
,
None
,
"marlin"
),
(
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"
,
"marlin"
,
"marlin"
),
(
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"
,
"gptq"
,
"marlin"
),
(
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"
,
"awq"
,
"ERROR"
),
# Model Serialized in Exllama Format.
(
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
,
None
,
"gptq_marlin"
),
(
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
,
"marlin"
,
"gptq_marlin"
),
...
...
tests/quantization/test_fp8.py
View file @
d2b52805
...
...
@@ -38,8 +38,7 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
with
vllm_runner
(
model_id
)
as
llm
:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs
=
llm
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
max_tokens
=
10
)
outputs
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
10
)
print
(
outputs
[
0
][
1
])
...
...
@@ -90,8 +89,7 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs
=
llm
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
max_tokens
=
10
)
outputs
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
10
)
print
(
outputs
[
0
][
1
])
...
...
tests/quantization/test_lm_head.py
View file @
d2b52805
...
...
@@ -11,7 +11,6 @@ import torch
from
vllm.model_executor.layers.quantization.gptq
import
GPTQLinearMethod
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinLinearMethod
)
from
vllm.model_executor.layers.quantization.marlin
import
MarlinLinearMethod
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
UnquantizedEmbeddingMethod
)
...
...
@@ -19,9 +18,7 @@ PROMPT = "On the surface of Mars, we found"
MODELS_QUANT
=
[
(
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head"
,
True
),
(
"ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024"
,
False
),
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
False
),
(
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
,
False
)
]
...
...
@@ -41,8 +38,7 @@ def test_lm_head(
lm_head_layer
=
model
.
lm_head
if
lm_head_quantized
:
assert
isinstance
(
lm_head_layer
.
quant_method
,
(
GPTQLinearMethod
,
GPTQMarlinLinearMethod
,
MarlinLinearMethod
))
(
GPTQLinearMethod
,
GPTQMarlinLinearMethod
))
else
:
assert
isinstance
(
lm_head_layer
.
quant_method
,
UnquantizedEmbeddingMethod
)
...
...
@@ -50,5 +46,5 @@ def test_lm_head(
vllm_model
.
apply_model
(
check_model
)
print
(
vllm_model
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
vllm_model
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
10
)[
0
][
1
])
tests/samplers/test_beam_search.py
View file @
d2b52805
...
...
@@ -67,6 +67,59 @@ def test_beam_search_single_input(
f
"vLLM:
{
vllm_output_ids
}
"
)
@
pytest
.
mark
.
skip_v1
# FIXME: This fails on V1 right now.
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
MAX_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
BEAM_WIDTHS
)
def
test_beam_search_with_concurrency_limit
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
beam_width
:
int
,
)
->
None
:
# example_prompts[1]&[3]&[7] fails due to unknown reason even without
# concurency limit. skip them for now.
example_prompts
=
(
example_prompts
[:
8
])
concurrency_limit
=
2
assert
len
(
example_prompts
)
>
concurrency_limit
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
outputs_with_limit
=
vllm_model
.
generate_beam_search
(
example_prompts
,
beam_width
,
max_tokens
,
concurrency_limit
=
concurrency_limit
)
outputs_without_limit
=
[]
for
i
in
range
(
0
,
len
(
example_prompts
),
concurrency_limit
):
outputs_without_limit
.
extend
(
vllm_model
.
generate_beam_search
(
example_prompts
[
i
:
i
+
concurrency_limit
],
beam_width
,
max_tokens
))
correct
=
True
for
i
in
range
(
len
(
example_prompts
)):
output_ids_with_limit
,
output_texts_with_limit
=
outputs_with_limit
[
i
]
output_ids_without_limit
,
output_texts_without_limit
=
(
outputs_without_limit
[
i
])
for
j
,
(
text_with_limit
,
text_without_limit
)
in
enumerate
(
zip
(
output_texts_with_limit
,
output_texts_without_limit
)):
print
(
f
">>>
{
j
}
-th with limit output:"
)
print
(
text_with_limit
)
print
(
f
">>>
{
j
}
-th without limit output:"
)
print
(
text_without_limit
)
assert
len
(
output_ids_with_limit
)
==
len
(
output_ids_without_limit
)
for
j
in
range
(
len
(
output_ids_with_limit
)):
if
output_ids_with_limit
[
j
]
!=
output_ids_without_limit
[
j
]:
print
(
f
"Test
{
i
}
output
{
j
}
:
\n
+limit:
{
output_ids_with_limit
}
\n
"
f
"-limit:
{
output_ids_without_limit
}
"
)
correct
=
False
assert
correct
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
MAX_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
MM_BEAM_WIDTHS
)
...
...
tests/samplers/test_sampler.py
deleted
100644 → 0
View file @
9a521c23
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
random
from
dataclasses
import
dataclass
from
typing
import
Optional
from
unittest.mock
import
Mock
,
patch
import
pytest
import
torch
from
transformers
import
GenerationConfig
,
GenerationMixin
import
vllm.envs
as
envs
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
Counter
,
is_pin_memory_available
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This file tests V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
class
MockLogitsSampler
(
Sampler
):
def
__init__
(
self
,
fake_logits
:
torch
.
Tensor
):
super
().
__init__
()
self
.
fake_logits
=
fake_logits
def
forward
(
self
,
*
args
,
**
kwargs
):
return
super
().
forward
(
*
args
,
**
kwargs
)
def
_prepare_test
(
batch_size
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsSampler
]:
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
dtype
=
torch
.
float16
)
fake_logits
=
torch
.
full
((
batch_size
,
VOCAB_SIZE
),
1e-2
,
dtype
=
input_tensor
.
dtype
)
sampler
=
MockLogitsSampler
(
fake_logits
)
return
input_tensor
,
fake_logits
,
sampler
VOCAB_SIZE
=
32000
RANDOM_SEEDS
=
list
(
range
(
128
))
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
def
_do_sample
(
batch_size
:
int
,
input_tensor
:
torch
.
Tensor
,
sampler
:
MockLogitsSampler
,
sampling_params
:
SamplingParams
,
device
:
str
,
):
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
list
[
int
]
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
return
sampler
(
logits
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_all_greedy
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
==
expected
[
i
].
item
()
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_all_random
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
for
i
in
range
(
batch_size
):
fake_logits
[
i
,
i
]
=
1e2
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
==
i
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_all_random_seed
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
for
i
in
range
(
batch_size
):
fake_logits
[
i
,
i
]
=
1e2
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
seed
=
random
.
randint
(
0
,
10000
),
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
==
i
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_all_random_seed_deterministic
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
seed
=
random
.
randint
(
0
,
10000
),
)
first_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
second_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
assert
first_sampler_output
==
second_sampler_output
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_min_tokens_penalty
(
seed
:
int
,
device
:
str
):
seq_id_counter
=
Counter
(
start
=
random
.
randint
(
0
,
100
))
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
def
create_sampling_params
(
min_tokens
,
eos_token_id
=
0
,
*
,
stop_token_ids
:
Optional
[
list
[
int
]]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
):
sampling_params
=
SamplingParams
(
min_tokens
=
min_tokens
,
max_tokens
=
9999
,
# keep higher than max of min_tokens
stop_token_ids
=
stop_token_ids
,
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs
=
prompt_logprobs
,
)
sampling_params
.
all_stop_token_ids
.
add
(
eos_token_id
)
return
sampling_params
def
create_sequence_data
(
num_input
=
3
,
num_generated
=
0
):
seq_data
=
SequenceData
.
from_seqs
(
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_input
))
if
num_generated
>
0
:
seq_data
.
output_token_ids
=
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_generated
)
return
seq_data
def
generate_test_case
():
# generate multiple seq groups but limit total batch size
batch_size
=
random
.
randint
(
1
,
128
)
expected_penalization
=
[]
sequence_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
# 20% chance to generate seq group metadata list with all prompts
is_prompt
=
random
.
random
()
<
0.2
while
batch_size
>
0
:
num_seqs
=
1
if
is_prompt
else
random
.
randint
(
1
,
batch_size
)
eos_token_id
=
random
.
randint
(
0
,
VOCAB_SIZE
-
1
)
min_tokens
=
random
.
randint
(
0
,
50
)
num_stop_tokens
=
random
.
randint
(
0
,
8
)
if
num_stop_tokens
>
0
:
stop_token_ids
=
random
.
choices
(
range
(
0
,
VOCAB_SIZE
-
1
),
k
=
num_stop_tokens
)
else
:
stop_token_ids
=
None
sampling_params
=
create_sampling_params
(
min_tokens
=
min_tokens
,
eos_token_id
=
eos_token_id
,
stop_token_ids
=
stop_token_ids
)
seq_data
:
dict
[
int
,
SequenceData
]
=
{}
seq_group_penalization
:
list
[
bool
]
=
[]
for
_
in
range
(
num_seqs
):
num_input
=
random
.
randint
(
1
,
100
)
num_generated
=
0
if
is_prompt
else
random
.
randint
(
1
,
100
)
seq_data
[
next
(
seq_id_counter
)]
=
create_sequence_data
(
num_input
=
num_input
,
num_generated
=
num_generated
)
seq_group_penalization
.
append
(
num_generated
<
min_tokens
)
expected_penalization
.
extend
(
seq_group_penalization
)
sequence_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
batch_size
}
"
,
is_prompt
=
is_prompt
,
seq_data
=
seq_data
,
sampling_params
=
sampling_params
,
block_tables
=
{},
))
batch_size
-=
num_seqs
return
{
"expected_penalization"
:
expected_penalization
,
"seq_group_metadata_list"
:
sequence_metadata_list
,
}
# define some explicit test cases for edge case behavior
prompt_without_penalization
=
{
"expected_penalization"
:
[
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(),
},
sampling_params
=
create_sampling_params
(
0
),
block_tables
=
{},
),
]
}
prompt_with_penalization
=
{
"expected_penalization"
:
[
True
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(),
},
sampling_params
=
create_sampling_params
(
1
),
block_tables
=
{},
),
]
}
prompt_with_penalization_and_prompt_logprobs
=
{
"expected_penalization"
:
[
False
,
False
,
True
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_input
=
3
),
},
sampling_params
=
create_sampling_params
(
1
,
prompt_logprobs
=
3
),
block_tables
=
{},
),
]
}
stop_penalizing_after_min_tokens
=
{
"expected_penalization"
:
[
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
False
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
1
),
},
sampling_params
=
create_sampling_params
(
1
),
block_tables
=
{},
)
]
}
stop_token_ids
=
[
42
,
99
,
42
,
0
]
# intentional duplication
prompt_combination
=
{
"expected_penalization"
:
[
False
,
True
,
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_2"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_input
=
2
),
},
sampling_params
=
create_sampling_params
(
1
,
prompt_logprobs
=
3
),
block_tables
=
{},
),
SequenceGroupMetadata
(
request_id
=
"test_3"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(),
},
sampling_params
=
create_sampling_params
(
0
,
stop_token_ids
=
stop_token_ids
),
block_tables
=
{},
)
]
}
stop_token_ids
=
[
1
,
999
,
37
,
37
]
# intentional duplication
decode_combination
=
{
"expected_penalization"
:
[
True
,
False
,
False
,
True
,
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
False
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
1
),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
100
),
},
sampling_params
=
create_sampling_params
(
2
,
stop_token_ids
=
stop_token_ids
),
block_tables
=
{},
),
SequenceGroupMetadata
(
request_id
=
"test_2"
,
is_prompt
=
False
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
20
),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
1
),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
10
),
},
sampling_params
=
create_sampling_params
(
10
,
prompt_logprobs
=
5
,
stop_token_ids
=
stop_token_ids
),
block_tables
=
{},
),
]
}
if
seed
==
0
:
test_cases
=
[
prompt_without_penalization
,
prompt_with_penalization
,
prompt_with_penalization_and_prompt_logprobs
,
stop_penalizing_after_min_tokens
,
prompt_combination
,
decode_combination
,
]
else
:
test_cases
=
[
generate_test_case
()]
def
run_test_case
(
*
,
expected_penalization
:
list
[
bool
],
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]):
assert
expected_penalization
,
\
"Invalid test case, need expected_penalization"
assert
seq_group_metadata_list
,
\
"Invalid test case, need seq_group_metadata_list"
batch_size
=
0
seq_lens
:
list
[
int
]
=
[]
sampling_params_per_row
:
list
[
SamplingParams
]
=
[]
for
sgm
in
seq_group_metadata_list
:
sampling_params
=
sgm
.
sampling_params
num_rows
=
len
(
sgm
.
seq_data
)
if
sgm
.
is_prompt
:
# a prompt seq_group has only one sequence
seq_data
=
next
(
iter
(
sgm
.
seq_data
.
values
()))
prompt_len
=
seq_data
.
get_prompt_len
()
seq_lens
.
append
(
prompt_len
)
assert
sgm
.
sampling_params
is
not
None
if
sgm
.
sampling_params
.
prompt_logprobs
:
# with prompt_logprobs each token in the prompt has a row in
# logits
num_rows
=
prompt_len
batch_size
+=
num_rows
sampling_params_per_row
.
extend
(
itertools
.
repeat
(
sampling_params
,
num_rows
))
assert
len
(
expected_penalization
)
==
batch_size
,
\
(
"Invalid test case, expected_penalization does not match computed"
"batch size"
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
=
seq_lens
if
seq_lens
else
None
,
query_lens
=
seq_lens
if
seq_lens
else
[
1
]
*
batch_size
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
# the logits tensor is modified in-place by the sampler
_
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
for
logits_idx
,
(
should_penalize
,
sampling_params
)
in
enumerate
(
zip
(
expected_penalization
,
sampling_params_per_row
)):
tokens_to_check
=
sampling_params
.
all_stop_token_ids
if
should_penalize
:
for
token_id
in
tokens_to_check
:
assert
fake_logits
[
logits_idx
,
token_id
]
==
-
float
(
'inf'
),
f
"Expected token
{
token_id
}
for logits row
{
logits_idx
}
"
" to be penalized"
# no other tokens should be set to -inf
assert
torch
.
count_nonzero
(
fake_logits
[
logits_idx
,
:]
==
-
float
(
'inf'
))
==
len
(
tokens_to_check
),
f
"Expected only
{
len
(
tokens_to_check
)
}
to be penalized"
else
:
# no tokens should be set to -inf
assert
torch
.
count_nonzero
(
fake_logits
[
logits_idx
,
:]
==
-
float
(
'inf'
))
==
0
,
"No tokens should have been penalized"
for
test_case
in
test_cases
:
run_test_case
(
**
test_case
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_mixed
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
expected_tokens
:
list
[
Optional
[
list
[
int
]]]
=
[]
seq_lens
:
list
[
int
]
=
[]
for
i
in
range
(
batch_size
):
expected
:
Optional
[
list
[
int
]]
=
None
sampling_type
=
random
.
randint
(
0
,
2
)
if
sampling_type
==
0
:
sampling_params
=
SamplingParams
(
temperature
=
0
)
expected
=
[
int
(
torch
.
argmax
(
fake_logits
[
i
],
dim
=-
1
).
item
())]
elif
sampling_type
in
(
1
,
2
):
n
=
random
.
randint
(
1
,
10
)
sampling_params
=
SamplingParams
(
temperature
=
random
.
random
()
+
0.1
,
top_p
=
min
(
random
.
random
()
+
0.1
,
1
),
top_k
=
random
.
randint
(
0
,
10
),
n
=
n
,
presence_penalty
=
random
.
randint
(
0
,
1
),
)
if
sampling_type
==
2
:
sampling_params
.
seed
=
random
.
randint
(
0
,
10000
)
else
:
for
idx
in
range
(
n
):
fake_logits
[
i
,
i
+
idx
]
=
1e2
expected
=
list
(
range
(
i
,
i
+
n
))
expected_tokens
.
append
(
expected
)
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
generators
:
dict
[
str
,
torch
.
Generator
]
=
{}
def
test_sampling
():
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
is_pin_memory_available
(),
generators
=
generators
)
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
for
i
,
(
sequence_output
,
metadata
)
in
enumerate
(
zip
(
sampler_output
,
seq_group_metadata_list
)):
assert
metadata
.
sampling_params
is
not
None
if
(
metadata
.
sampling_params
.
seed
is
not
None
and
expected_tokens
[
i
]
is
None
):
# Record seeded random result to compare with results of
# second invocation
expected_tokens
[
i
]
=
[
nth_output
.
output_token
for
nth_output
in
sequence_output
.
samples
]
continue
expected_tokens_item
=
expected_tokens
[
i
]
assert
expected_tokens_item
is
not
None
for
n
,
nth_output
in
enumerate
(
sequence_output
.
samples
):
assert
metadata
.
sampling_params
is
not
None
if
(
metadata
.
sampling_params
.
temperature
==
0
or
metadata
.
sampling_params
.
seed
is
not
None
):
# Ensure exact matches for greedy or random with seed
assert
nth_output
.
output_token
==
expected_tokens_item
[
n
]
else
:
# For non-seeded random check that one of the high-logit
# tokens were chosen
assert
nth_output
.
output_token
in
expected_tokens_item
# Test batch
test_sampling
()
# Shuffle the batch and resample
target_index
=
list
(
range
(
batch_size
))
for
list_to_shuffle
in
(
target_index
,
seq_group_metadata_list
,
expected_tokens
,
seq_lens
):
random
.
Random
(
seed
).
shuffle
(
list_to_shuffle
)
target_index
=
torch
.
tensor
(
target_index
)
input_tensor
.
data
=
input_tensor
.
index_select
(
0
,
target_index
)
fake_logits
.
data
=
fake_logits
.
index_select
(
0
,
target_index
)
# This time, results of seeded random samples will be compared with
# the corresponding sample in the pre-shuffled batch
test_sampling
()
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_top_k_top_p
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
batch_size
=
random
.
randint
(
1
,
256
)
top_k
=
random
.
randint
(
100
,
500
)
top_p
=
random
.
random
()
*
0.1
vocab_size
=
32000
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
device
=
device
,
dtype
=
torch
.
float16
)
fake_logits
=
torch
.
normal
(
0
,
5
,
size
=
(
batch_size
,
vocab_size
),
device
=
input_tensor
.
device
,
dtype
=
input_tensor
.
dtype
)
sampler
=
MockLogitsSampler
(
fake_logits
)
generation_model
=
GenerationMixin
()
generation_config
=
GenerationConfig
(
top_k
=
top_k
,
top_p
=
top_p
,
do_sample
=
True
)
@
dataclass
class
MockConfig
:
is_encoder_decoder
:
bool
=
False
generation_model
.
config
=
MockConfig
()
# needed by the following method
generation_model
.
_prepare_special_tokens
(
generation_config
,
device
=
device
)
processors
=
generation_model
.
_get_logits_processor
(
generation_config
,
None
,
None
,
None
,
[],
device
=
device
)
assert
len
(
processors
)
==
2
# top_p and top_k
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
list
[
int
]
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
SamplingParams
(
temperature
=
1
,
top_k
=
top_k
,
top_p
=
top_p
,
),
block_tables
=
{
0
:
[
1
]},
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
sample_probs
=
None
def
mock_sample
(
probs
,
*
args
,
**
kwargs
):
nonlocal
sample_probs
sample_probs
=
probs
return
([[
prob
.
topk
(
1
,
dim
=-
1
).
indices
.
tolist
(),
[
0
]]
for
prob
in
probs
],
None
)
# top-k and top-p is only calculated when flashinfer kernel is not available
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
),
\
patch
(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling"
,
None
):
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
assert
sample_probs
is
not
None
hf_probs
=
processors
(
torch
.
zeros_like
(
fake_logits
),
fake_logits
.
clone
())
hf_probs
=
torch
.
softmax
(
hf_probs
,
dim
=-
1
,
dtype
=
torch
.
float
)
torch
.
testing
.
assert_close
(
hf_probs
,
sample_probs
,
rtol
=
0.0
,
atol
=
1e-5
)
assert
torch
.
equal
(
hf_probs
.
eq
(
0
),
sample_probs
.
eq
(
0
))
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_flashinfer_fallback
(
seed
:
int
,
device
:
str
):
if
not
envs
.
VLLM_USE_FLASHINFER_SAMPLER
:
pytest
.
skip
(
"Flashinfer sampler is disabled"
)
pytest
.
skip
(
"After FlashInfer 0.2.3, sampling will never fail"
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
def
failing_flashinfer_sampling
(
*
_args
,
**
_kwargs
):
return
None
,
torch
.
zeros
(
batch_size
,
device
=
device
,
dtype
=
torch
.
int32
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
seed
=
random
.
randint
(
0
,
10000
),
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
with
patch
(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling"
,
failing_flashinfer_sampling
):
fallback_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
assert
sampler_output
==
fallback_sampler_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_repetition_penalty_mixed
(
device
:
str
):
vocab_size
=
8
def
test_sampling_params
(
sampling_params
:
list
[
SamplingParams
]):
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
list
[
int
]
=
[]
for
i
in
range
(
2
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
sampling_params
[
i
],
block_tables
=
{
0
:
[
1
]},
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
fake_logits
=
torch
.
full
((
2
,
vocab_size
),
1e-2
,
device
=
device
,
dtype
=
torch
.
float16
)
fake_logits
[:,
5
]
=
1.1e-2
fake_logits
[:,
1
]
=
1.2e-2
sampler
=
MockLogitsSampler
(
fake_logits
)
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
generated_tokens
=
[]
for
output
in
sampler_output
:
generated_tokens
.
append
(
output
.
samples
[
0
].
output_token
)
return
generated_tokens
# one configuration is greedy with repetition_penalty
sampling_params_rep
=
SamplingParams
(
temperature
=
0.0
,
repetition_penalty
=
2.0
,
)
# other configuration is sampling w/o repetition_penalty
sampling_params_sample
=
SamplingParams
(
temperature
=
1.0
,
top_k
=
1
,
seed
=
42
,
)
tokens1
=
test_sampling_params
(
[
sampling_params_rep
,
sampling_params_sample
])
tokens2
=
test_sampling_params
(
[
sampling_params_sample
,
sampling_params_rep
])
assert
tokens1
[
0
]
==
tokens2
[
1
]
assert
tokens1
[
1
]
==
tokens2
[
0
]
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_include_gpu_probs_tensor
(
device
:
str
):
set_random_seed
(
42
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
sampler
.
include_gpu_probs_tensor
=
True
sampler
.
should_modify_greedy_probs_inplace
=
False
sampling_params
=
SamplingParams
(
temperature
=
0
)
mock_inplace
=
Mock
()
with
patch
(
"vllm.model_executor.layers.sampler._modify_greedy_probs_inplace"
,
mock_inplace
):
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
mock_inplace
.
assert_not_called
()
assert
sampler_output
.
sampled_token_probs
is
not
None
assert
sampler_output
.
logprobs
is
not
None
assert
sampler_output
.
sampled_token_ids
is
not
None
tests/samplers/test_seeded_generate.py
deleted
100644 → 0
View file @
9a521c23
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Verify that seeded random sampling is deterministic.
Run `pytest tests/samplers/test_seeded_generate.py`.
"""
import
copy
import
random
from
itertools
import
combinations
import
pytest
from
vllm
import
SamplingParams
from
vllm.model_executor.utils
import
set_random_seed
MODEL
=
"facebook/opt-125m"
RANDOM_SEEDS
=
list
(
range
(
5
))
@
pytest
.
fixture
def
vllm_model
(
vllm_runner
,
monkeypatch
):
# This file relies on V0 internals.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
with
vllm_runner
(
MODEL
,
dtype
=
"half"
)
as
vllm_model
:
yield
vllm_model
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
def
test_random_sample_with_seed
(
vllm_model
,
example_prompts
,
seed
:
int
,
)
->
None
:
set_random_seed
(
seed
)
sampling_params
=
SamplingParams
(
# Parameters to ensure sufficient randomness
temperature
=
3.0
,
top_p
=
min
(
random
.
random
()
+
0.3
,
1
),
top_k
=
random
.
randint
(
5
,
20
),
n
=
random
.
randint
(
1
,
10
),
presence_penalty
=
random
.
randint
(
0
,
1
),
max_tokens
=
8
,
ignore_eos
=
True
,
)
sampling_params_seed_1
=
copy
.
deepcopy
(
sampling_params
)
sampling_params_seed_1
.
seed
=
100
sampling_params_seed_2
=
copy
.
deepcopy
(
sampling_params
)
sampling_params_seed_2
.
seed
=
200
llm
=
vllm_model
.
llm
for
prompt
in
example_prompts
:
for
params
in
(
sampling_params
,
sampling_params_seed_1
,
sampling_params_seed_2
,
sampling_params
,
sampling_params_seed_1
,
sampling_params_seed_2
,
):
llm
.
_add_request
(
prompt
,
params
=
params
)
results
=
llm
.
_run_engine
(
use_tqdm
=
False
)
all_outputs
=
[[
out
.
token_ids
for
out
in
output
.
outputs
]
for
output
in
results
]
for
i
in
range
(
0
,
len
(
example_prompts
),
6
):
outputs
=
all_outputs
[
i
:
i
+
6
]
# verify all non-seeded requests differ
for
output_a
,
output_b
in
combinations
(
(
outputs
[
0
],
outputs
[
1
],
outputs
[
2
],
outputs
[
3
]),
2
,
):
assert
output_a
!=
output_b
# verify requests with the same seed match
assert
outputs
[
1
]
==
outputs
[
4
]
assert
outputs
[
2
]
==
outputs
[
5
]
# verify generations within the same parallel sampling group differ
for
output
in
outputs
:
for
sub_output_a
,
sub_output_b
in
combinations
(
output
,
2
):
assert
sub_output_a
!=
sub_output_b
tests/test_sequence.py
View file @
d2b52805
...
...
@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
SequenceData
,
SequenceOutput
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
SequenceData
,
SequenceOutput
)
from
.core.utils
import
create_dummy_prompt
...
...
@@ -98,3 +99,38 @@ def test_sequence_group_stage():
assert
seq_group
.
is_prefill
()
is
True
seq_group
.
update_num_computed_tokens
(
1
)
assert
seq_group
.
is_prefill
()
is
False
def
test_sequence_intermediate_tensors_equal
():
class
AnotherIntermediateTensors
(
IntermediateTensors
):
pass
intermediate_tensors
=
IntermediateTensors
({})
another_intermediate_tensors
=
AnotherIntermediateTensors
({})
assert
intermediate_tensors
!=
another_intermediate_tensors
empty_intermediate_tensors_1
=
IntermediateTensors
({})
empty_intermediate_tensors_2
=
IntermediateTensors
({})
assert
empty_intermediate_tensors_1
==
empty_intermediate_tensors_2
different_key_intermediate_tensors_1
=
IntermediateTensors
(
{
"1"
:
torch
.
zeros
([
2
,
4
],
dtype
=
torch
.
int32
)})
difference_key_intermediate_tensors_2
=
IntermediateTensors
(
{
"2"
:
torch
.
zeros
([
2
,
4
],
dtype
=
torch
.
int32
)})
assert
(
different_key_intermediate_tensors_1
!=
difference_key_intermediate_tensors_2
)
same_key_different_value_intermediate_tensors_1
=
IntermediateTensors
(
{
"1"
:
torch
.
zeros
([
2
,
4
],
dtype
=
torch
.
int32
)})
same_key_different_value_intermediate_tensors_2
=
IntermediateTensors
(
{
"1"
:
torch
.
zeros
([
2
,
5
],
dtype
=
torch
.
int32
)})
assert
(
same_key_different_value_intermediate_tensors_1
!=
same_key_different_value_intermediate_tensors_2
)
same_key_same_value_intermediate_tensors_1
=
IntermediateTensors
(
{
"1"
:
torch
.
zeros
([
2
,
4
],
dtype
=
torch
.
int32
)})
same_key_same_value_intermediate_tensors_2
=
IntermediateTensors
(
{
"1"
:
torch
.
zeros
([
2
,
4
],
dtype
=
torch
.
int32
)})
assert
(
same_key_same_value_intermediate_tensors_1
==
same_key_same_value_intermediate_tensors_2
)
tests/tokenization/test_detokenize.py
View file @
d2b52805
...
...
@@ -64,8 +64,6 @@ def _run_incremental_decode(tokenizer,
request
=
EngineCoreRequest
(
""
,
prompt_token_ids
,
None
,
None
,
None
,
params
,
None
,
None
,
...
...
tests/tool_use/test_qwen3coder_tool_parser.py
View file @
d2b52805
...
...
@@ -16,7 +16,7 @@ from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
from
vllm.transformers_utils.detokenizer
import
detokenize_incrementally
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
get_tokenizer
MODEL
=
"Qwen/Qwen3-Coder-
48
0B-A3
5
B-Instruct-FP8"
MODEL
=
"Qwen/Qwen3-Coder-
3
0B-A3B-Instruct-FP8"
@
pytest
.
fixture
(
scope
=
"module"
)
...
...
@@ -397,7 +397,9 @@ hello world
"no_tools"
,
"single_tool"
,
"single_tool_with_content"
,
"single_tool_multiline_param"
,
"parallel_tools"
,
"tool_with_typed_params"
,
# Added this test case
],
argnames
=
[
"model_output"
,
"expected_tool_calls"
,
"expected_content"
],
argvalues
=
[
...
...
@@ -422,7 +424,7 @@ fahrenheit
"state"
:
"TX"
,
"unit"
:
"fahrenheit"
})))
],
""
),
],
None
),
(
'''Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
...
...
@@ -445,6 +447,30 @@ fahrenheit
})))
],
"Sure! Let me check the weather for you."
),
(
'''<tool_call>
<function=calculate_area>
<parameter=shape>
rectangle
</parameter>
<parameter=dimensions>
{"width": 10,
"height": 20}
</parameter>
<parameter=precision>
2
</parameter>
</function>
</tool_call>'''
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"calculate_area"
,
arguments
=
json
.
dumps
({
"shape"
:
"rectangle"
,
"dimensions"
:
{
"width"
:
10
,
"height"
:
20
},
"precision"
:
2
})))
],
None
),
(
'''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
...
...
@@ -484,13 +510,36 @@ celsius
"state"
:
"FL"
,
"unit"
:
"celsius"
})))
],
""
),
],
None
),
# Added tool_with_typed_params test case
(
'''Let me calculate that area for you.<tool_call>
<function=calculate_area>
<parameter=shape>
circle
</parameter>
<parameter=dimensions>
{"radius": 15.5}
</parameter>
<parameter=precision>
3
</parameter>
</function>
</tool_call>'''
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"calculate_area"
,
arguments
=
json
.
dumps
({
"shape"
:
"circle"
,
"dimensions"
:
{
"radius"
:
15.5
},
"precision"
:
3
})))
],
"Let me calculate that area for you."
),
],
)
def
test_extract_tool_calls_streaming
(
qwen3_tool_parser
,
qwen3_tokenizer
,
sample_tools
,
model_output
,
expected_tool_calls
,
expected_content
):
"""Test incremental streaming behavior"""
"""Test incremental streaming behavior
including typed parameters
"""
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
...
...
@@ -539,7 +588,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
"arguments"
]
+=
tool_call
.
function
.
arguments
# Verify final content
assert
other_content
==
expected_content
assert
other_content
==
(
expected_content
or
""
)
# Handle None case
# Verify we got all expected tool calls
assert
len
(
tool_states
)
==
len
(
expected_tool_calls
)
...
...
@@ -559,6 +608,125 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
assert
actual_args
==
expected_args
def
test_extract_tool_calls_missing_closing_parameter_tag
(
qwen3_tool_parser
,
sample_tools
):
"""Test handling of missing closing </parameter> tag"""
# Using get_current_weather from sample_tools but with malformed XML
model_output
=
'''Let me check the weather for you:
<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>'''
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
extracted_tool_calls
=
qwen3_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
request
)
# The parser should handle the malformed XML gracefully
assert
extracted_tool_calls
.
tools_called
assert
len
(
extracted_tool_calls
.
tool_calls
)
==
1
# Verify the function name is correct
assert
extracted_tool_calls
.
tool_calls
[
0
].
function
.
name
==
"get_current_weather"
# Verify the arguments are parsed despite the missing closing tag
args
=
json
.
loads
(
extracted_tool_calls
.
tool_calls
[
0
].
function
.
arguments
)
assert
"city"
in
args
assert
args
[
"city"
]
==
"Dallas"
assert
args
[
"state"
]
==
"TX"
assert
args
[
"unit"
]
==
"fahrenheit"
# Check that content before the tool call is preserved
assert
"Let me check the weather for you:"
in
extracted_tool_calls
.
content
def
test_extract_tool_calls_streaming_missing_closing_tag
(
qwen3_tool_parser
,
qwen3_tokenizer
,
sample_tools
):
"""Test streaming with missing closing </parameter> tag"""
# Using get_current_weather from sample_tools but with malformed XML
model_output
=
'''Let me check the weather for you:
<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>'''
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
other_content
=
''
tool_states
=
{}
for
delta_message
in
stream_delta_message_generator
(
qwen3_tool_parser
,
qwen3_tokenizer
,
model_output
,
request
):
if
delta_message
.
content
:
other_content
+=
delta_message
.
content
if
delta_message
.
tool_calls
:
for
tool_call
in
delta_message
.
tool_calls
:
idx
=
tool_call
.
index
if
idx
not
in
tool_states
:
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"type"
:
None
}
if
tool_call
.
id
:
tool_states
[
idx
][
"id"
]
=
tool_call
.
id
if
tool_call
.
type
:
assert
tool_call
.
type
==
"function"
tool_states
[
idx
][
"type"
]
=
tool_call
.
type
if
tool_call
.
function
:
if
tool_call
.
function
.
name
:
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
# Verify content was streamed
assert
"Let me check the weather for you:"
in
other_content
# Verify we got the tool call
assert
len
(
tool_states
)
==
1
state
=
tool_states
[
0
]
assert
state
[
"id"
]
is
not
None
assert
state
[
"type"
]
==
"function"
assert
state
[
"name"
]
==
"get_current_weather"
# Verify arguments were parsed correctly despite missing closing tag
assert
state
[
"arguments"
]
is
not
None
args
=
json
.
loads
(
state
[
"arguments"
])
assert
args
[
"city"
]
==
"Dallas"
assert
args
[
"state"
]
==
"TX"
assert
args
[
"unit"
]
==
"fahrenheit"
def
test_extract_tool_calls_streaming_incremental
(
qwen3_tool_parser
,
qwen3_tokenizer
,
sample_tools
):
...
...
tests/tool_use/test_seed_oss_tool_parser.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import
json
from
collections.abc
import
Generator
from
typing
import
Optional
import
pytest
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionToolsParam
,
DeltaMessage
,
FunctionCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers
import
SeedOssToolParser
from
vllm.transformers_utils.detokenizer
import
detokenize_incrementally
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
get_tokenizer
# Use a common model that is likely to be available
MODEL
=
"ByteDance-Seed/Seed-OSS-36B-Instruct"
@
pytest
.
fixture
(
scope
=
"module"
)
def
seed_oss_tokenizer
():
return
get_tokenizer
(
tokenizer_name
=
MODEL
,
trust_remote_code
=
True
)
@
pytest
.
fixture
def
seed_oss_tool_parser
(
seed_oss_tokenizer
):
return
SeedOssToolParser
(
seed_oss_tokenizer
)
@
pytest
.
fixture
def
sample_tools
():
return
[
ChatCompletionToolsParam
(
type
=
"function"
,
function
=
{
"name"
:
"get_weather"
,
"description"
:
"Get current temperature for a given location."
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"location"
:
{
"type"
:
"string"
,
"description"
:
"City and country e.g. Bogotá, Colombia"
},
"unit"
:
{
"type"
:
"string"
,
"description"
:
"this is the unit of temperature"
}
},
"required"
:
[
"location"
],
"additionalProperties"
:
False
},
"returns"
:
{
"type"
:
"object"
,
"properties"
:
{
"temperature"
:
{
"type"
:
"number"
,
"description"
:
"temperature in celsius"
}
},
"required"
:
[
"temperature"
],
"additionalProperties"
:
False
},
"strict"
:
True
}),
]
def
assert_tool_calls
(
actual_tool_calls
:
list
[
ToolCall
],
expected_tool_calls
:
list
[
ToolCall
]):
assert
len
(
actual_tool_calls
)
==
len
(
expected_tool_calls
)
for
actual_tool_call
,
expected_tool_call
in
zip
(
actual_tool_calls
,
expected_tool_calls
):
# Seed-OSS tool call will not generate id
assert
actual_tool_call
.
type
==
"function"
assert
actual_tool_call
.
function
==
expected_tool_call
.
function
assert
actual_tool_call
.
function
.
name
==
expected_tool_call
.
function
.
name
assert
actual_tool_call
.
function
.
arguments
==
expected_tool_call
.
function
.
arguments
def
test_extract_tool_calls_no_tools
(
seed_oss_tool_parser
):
model_output
=
"This is a test response without any tool calls"
extracted_tool_calls
=
seed_oss_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
None
)
# type: ignore[arg-type]
assert
not
extracted_tool_calls
.
tools_called
assert
extracted_tool_calls
.
tool_calls
==
[]
assert
extracted_tool_calls
.
content
==
model_output
@
pytest
.
mark
.
parametrize
(
ids
=
[
"tool_call_0_thinking_budget"
,
"tool_call_512_thinkg_budget"
,
"tool_call_unlimited_thinking_budget"
,
],
argnames
=
[
"model_output"
,
"expected_tool_calls"
,
"expected_content"
],
argvalues
=
[
(
"""<seed:tool_call>
\n
<function=get_weather>
\n
"""
"""<parameter=location>Barcelona, Spain</parameter>
\n
</function>
\n
</seed:tool_call>"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_weather"
,
arguments
=
json
.
dumps
({
"location"
:
"Barcelona, Spain"
,
},
),
),
type
=
'function'
)
],
None
),
(
"""<seed:think>The user
\'
s current thinking budget is 512.</seed:cot_budget_reflect>
\n
Let me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there
\'
s a get_weather function that can retrieve the current temperature for a given location.
\n\n
First, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country).
\n
<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>
\n
Since the unit isn
\'
t specified, the function will default to Celsius, which """
"""is fine.
\n\n
There
\'
s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user
\'
s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).
\n
<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>
\n
The unit parameter can be omitted since it
\'
s optional.</seed:think>
\n
"""
"""<seed:tool_call>
\n
<function=get_weather>
\n
<parameter=location>Barcelona, Spain</parameter>
\n
</function>"""
"""
\n
</seed:tool_call>"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_weather"
,
arguments
=
json
.
dumps
({
"location"
:
"Barcelona, Spain"
,
},
),
),
type
=
'function'
)
],
"""<seed:think>The user
\'
s current thinking budget is 512.</seed:cot_budget_reflect>
\n
Let me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there
\'
s a get_weather function that can retrieve the current temperature for a given location.
\n\n
First, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country).
\n
<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>
\n
Since the unit isn
\'
t specified, the function will default to Celsius, which """
"""is fine.
\n\n
There
\'
s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user
\'
s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).
\n
<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>
\n
The unit parameter can be omitted since it
\'
s optional.</seed:think>
\n
"""
,
),
(
"""<seed:think>
\n
Got it, let
\'
s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn
\'
t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>
\n
<function=get_weather>
\n
<parameter=location>Barcelona, """
"""Spain</parameter>
\n
<parameter=unit>celsius</parameter>
\n
</function>
\n
</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don
\'
t pass unit, """
"""it
\'
s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think><seed:tool_call>
\n
<function=get_weather>
\n
<parameter=location>"""
"""Barcelona, Spain</parameter>
\n
<parameter=unit>celsius</parameter>
\n
</function>
\n
</seed:tool_call>"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_weather"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Barcelona, Spain"
,
"unit"
:
"celsius"
,
},
),
),
type
=
'function'
)
],
"""<seed:think>
\n
Got it, let
\'
s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn
\'
t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>
\n
<function=get_weather>
\n
<parameter=location>Barcelona, """
"""Spain</parameter>
\n
<parameter=unit>celsius</parameter>
\n
</function>
\n
</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don
\'
t pass unit, """
"""it
\'
s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think>"""
,
),
],
)
def
test_extract_tool_calls
(
seed_oss_tool_parser
,
sample_tools
,
model_output
,
expected_tool_calls
,
expected_content
):
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
extracted_tool_calls
=
seed_oss_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
request
)
# type: ignore[arg-type]
assert
extracted_tool_calls
.
tools_called
assert_tool_calls
(
extracted_tool_calls
.
tool_calls
,
expected_tool_calls
)
assert
extracted_tool_calls
.
content
==
expected_content
def
test_streaming_tool_calls_no_tools
(
seed_oss_tool_parser
):
model_output
=
"This is a test response without any tool calls"
result
=
seed_oss_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
"his is a test response"
,
current_text
=
model_output
,
delta_text
=
" without any tool calls."
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
None
,
)
# Should return the delta text as content
assert
result
is
not
None
assert
hasattr
(
result
,
'content'
)
assert
result
.
content
==
" without any tool calls."
def
stream_delta_message_generator
(
seed_oss_tool_parser
:
SeedOssToolParser
,
seed_oss_tokenizer
:
AnyTokenizer
,
model_output
:
str
,
request
:
Optional
[
ChatCompletionRequest
]
=
None
)
->
Generator
[
DeltaMessage
,
None
,
None
]:
all_token_ids
=
seed_oss_tokenizer
.
encode
(
model_output
,
add_special_tokens
=
False
)
previous_text
=
""
previous_tokens
=
None
prefix_offset
=
0
read_offset
=
0
for
i
,
delta_token
in
enumerate
(
all_token_ids
):
delta_token_ids
=
[
delta_token
]
previous_token_ids
=
all_token_ids
[:
i
]
current_token_ids
=
all_token_ids
[:
i
+
1
]
(
new_tokens
,
delta_text
,
new_prefix_offset
,
new_read_offset
)
=
detokenize_incrementally
(
tokenizer
=
seed_oss_tokenizer
,
all_input_ids
=
current_token_ids
,
prev_tokens
=
previous_tokens
,
prefix_offset
=
prefix_offset
,
read_offset
=
read_offset
,
skip_special_tokens
=
False
,
spaces_between_special_tokens
=
True
,
)
current_text
=
previous_text
+
delta_text
delta_message
=
seed_oss_tool_parser
.
extract_tool_calls_streaming
(
previous_text
,
current_text
,
delta_text
,
previous_token_ids
,
current_token_ids
,
delta_token_ids
,
request
=
request
,
)
if
delta_message
:
yield
delta_message
previous_text
=
current_text
previous_tokens
=
(
previous_tokens
+
new_tokens
if
previous_tokens
else
new_tokens
)
prefix_offset
=
new_prefix_offset
read_offset
=
new_read_offset
@
pytest
.
mark
.
parametrize
(
ids
=
[
"tool_call_0_thinking_budget"
,
"tool_call_512_thinkg_budget"
,
"tool_call_unlimited_thinking_budget"
,
],
argnames
=
[
"model_output"
,
"expected_tool_calls"
,
"expected_content"
],
argvalues
=
[
(
"""<seed:think>
\n
</seed:cot_budget_reflect>
\n
</seed:cot_budget_reflect>
\n
"""
"""The current thinking budget is 0, so I will directly start answering the question.
\n
</seed:think>
\n
"""
"""<seed:tool_call>
\n
<function=get_weather>
\n
"""
"""<parameter=location>Barcelona, Spain</parameter>
\n
</function>
\n
</seed:tool_call>"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_weather"
,
arguments
=
json
.
dumps
({
"location"
:
"Barcelona, Spain"
,
},
),
),
type
=
'function'
)
],
"""<seed:think>
\n
</seed:cot_budget_reflect>
\n
</seed:cot_budget_reflect>
\n
"""
"""The current thinking budget is 0, so I will directly start answering the question.
\n
</seed:think>
\n
"""
),
(
"""<seed:think>The user
\'
s current thinking budget is 512.</seed:cot_budget_reflect>
\n
Let me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there
\'
s a get_weather function that can retrieve the current temperature for a given location.
\n\n
First, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country).
\n
<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>
\n
Since the unit isn
\'
t specified, the function will default to Celsius, which """
"""is fine.
\n\n
There
\'
s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user
\'
s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).
\n
<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>
\n
The unit parameter can be omitted since it
\'
s optional.</seed:think>
\n
"""
"""<seed:tool_call>
\n
<function=get_weather>
\n
<parameter=location>Barcelona, Spain</parameter>
\n
</function>"""
"""
\n
</seed:tool_call>"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_weather"
,
arguments
=
json
.
dumps
({
"location"
:
"Barcelona, Spain"
,
},
),
),
type
=
'function'
)
],
"""<seed:think>The user
\'
s current thinking budget is 512.</seed:cot_budget_reflect>
\n
Let me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there
\'
s a get_weather function that can retrieve the current temperature for a given location.
\n\n
First, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country).
\n
<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>
\n
Since the unit isn
\'
t specified, the function will default to Celsius, which """
"""is fine.
\n\n
There
\'
s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user
\'
s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).
\n
<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>
\n
The unit parameter can be omitted since it
\'
s optional.</seed:think>
\n
"""
,
),
(
"""<seed:think>
\n
Got it, let
\'
s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn
\'
t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>
\n
<function=get_weather>
\n
<parameter=location>Barcelona, """
"""Spain</parameter>
\n
<parameter=unit>celsius</parameter>
\n
</function>
\n
</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don
\'
t pass unit, """
"""it
\'
s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think><seed:tool_call>
\n
<function=get_weather>
\n
<parameter=location>"""
"""Barcelona, Spain</parameter>
\n
<parameter=unit>celsius</parameter>
\n
</function>
\n
</seed:tool_call>"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_weather"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Barcelona, Spain"
,
"unit"
:
"celsius"
,
},
),
),
type
=
'function'
)
],
"""<seed:think>
\n
Got it, let
\'
s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn
\'
t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>
\n
<function=get_weather>
\n
<parameter=location>Barcelona, """
"""Spain</parameter>
\n
<parameter=unit>celsius</parameter>
\n
</function>
\n
</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don
\'
t pass unit, """
"""it
\'
s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think>"""
,
),
],
)
def
test_streaming_tool_calls
(
seed_oss_tool_parser
,
seed_oss_tokenizer
,
sample_tools
,
model_output
,
expected_tool_calls
,
expected_content
):
"""Test incremental streaming behavior"""
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
other_content
=
''
tool_states
=
{}
# Track state per tool index
for
delta_message
in
stream_delta_message_generator
(
seed_oss_tool_parser
,
seed_oss_tokenizer
,
model_output
,
request
):
# role should never be streamed from tool parser
assert
not
delta_message
.
role
if
delta_message
.
content
:
other_content
+=
delta_message
.
content
if
delta_message
.
tool_calls
:
for
tool_call
in
delta_message
.
tool_calls
:
idx
=
tool_call
.
index
# Initialize state for new tool
if
idx
not
in
tool_states
:
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"type"
:
None
}
# First chunk should have id, name, and type
if
tool_call
.
id
:
tool_states
[
idx
][
"id"
]
=
tool_call
.
id
if
tool_call
.
type
:
assert
tool_call
.
type
==
"function"
tool_states
[
idx
][
"type"
]
=
tool_call
.
type
if
tool_call
.
function
:
if
tool_call
.
function
.
name
:
# Should only be set once
assert
tool_states
[
idx
][
"name"
]
is
None
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
# Accumulate arguments incrementally
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
# Verify final content
assert
other_content
==
expected_content
# Verify we got all expected tool calls
assert
len
(
tool_states
)
==
len
(
expected_tool_calls
)
# Verify each tool call
for
idx
,
expected_tool
in
enumerate
(
expected_tool_calls
):
state
=
tool_states
[
idx
]
assert
state
[
"id"
]
is
not
None
assert
state
[
"type"
]
==
"function"
assert
state
[
"name"
]
==
expected_tool
.
function
.
name
# Parse accumulated arguments
arguments_str
=
state
[
"arguments"
]
assert
arguments_str
is
not
None
actual_args
=
json
.
loads
(
arguments_str
)
expected_args
=
json
.
loads
(
expected_tool
.
function
.
arguments
)
assert
actual_args
==
expected_args
Prev
1
…
12
13
14
15
16
17
18
19
20
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment