Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
711aa9d5
Commit
711aa9d5
authored
Jul 30, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.0' into v0.10.0-dev
parents
751c492c
6d8d0a24
Changes
519
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1389 additions
and
1593 deletions
+1389
-1593
tests/samplers/test_no_bad_words.py
tests/samplers/test_no_bad_words.py
+6
-6
tests/samplers/test_seeded_generate.py
tests/samplers/test_seeded_generate.py
+1
-1
tests/samplers/test_typical_acceptance_sampler.py
tests/samplers/test_typical_acceptance_sampler.py
+0
-480
tests/spec_decode/conftest.py
tests/spec_decode/conftest.py
+0
-12
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+0
-307
tests/spec_decode/test_batch_expansion.py
tests/spec_decode/test_batch_expansion.py
+0
-110
tests/spec_decode/test_dynamic_spec_decode.py
tests/spec_decode/test_dynamic_spec_decode.py
+0
-90
tests/spec_decode/test_memory_usage.py
tests/spec_decode/test_memory_usage.py
+0
-91
tests/spec_decode/test_utils.py
tests/spec_decode/test_utils.py
+0
-150
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+0
-290
tests/tensorizer_loader/conftest.py
tests/tensorizer_loader/conftest.py
+85
-0
tests/tensorizer_loader/test_tensorizer.py
tests/tensorizer_loader/test_tensorizer.py
+303
-40
tests/test_config.py
tests/test_config.py
+81
-10
tests/test_sequence.py
tests/test_sequence.py
+0
-1
tests/test_utils.py
tests/test_utils.py
+80
-3
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+1
-1
tests/tokenization/test_do_lower_case.py
tests/tokenization/test_do_lower_case.py
+18
-0
tests/tool_use/test_glm4_moe_tool_parser.py
tests/tool_use/test_glm4_moe_tool_parser.py
+3
-1
tests/tool_use/test_kimi_k2_tool_parser.py
tests/tool_use/test_kimi_k2_tool_parser.py
+193
-0
tests/tool_use/test_qwen3coder_tool_parser.py
tests/tool_use/test_qwen3coder_tool_parser.py
+618
-0
No files found.
Too many changes to show.
To preserve performance only
519 of 519+
files are displayed.
Plain diff
Email patch
tests/samplers/test_no_bad_words.py
View file @
711aa9d5
...
...
@@ -21,7 +21,7 @@ def v1(run_with_both_engines):
def
_generate
(
model
:
LLM
,
llm
:
LLM
,
prompt
:
str
,
num_prompt_tokens
:
int
,
temperature
:
float
=
0
,
...
...
@@ -33,7 +33,7 @@ def _generate(
)
# [([output_token_ids, ], [output_text, ]), ]
output
=
model
.
generate
([
prompt
],
sampling_params
=
sampling_params
)
output
=
llm
.
generate
([
prompt
],
sampling_params
=
sampling_params
)
output_token_ids
=
output
[
0
][
0
][
0
][
num_prompt_tokens
:]
# [0] first (and only) request output
...
...
@@ -68,10 +68,10 @@ class TestOneTokenBadWord:
assert
self
.
target_token_id
not
in
output_token_ids
def
_generate
(
self
,
model
:
LLM
,
llm
:
LLM
,
bad_words
:
Optional
[
list
[
str
]]
=
None
)
->
list
[
int
]:
return
_generate
(
model
=
model
,
llm
=
llm
,
prompt
=
self
.
PROMPT
,
num_prompt_tokens
=
self
.
num_prompt_tokens
,
bad_words
=
bad_words
,
...
...
@@ -158,10 +158,10 @@ class TestTwoTokenBadWord:
or
(
self
.
neighbour_token_id2
in
output_token_ids
))
def
_generate
(
self
,
model
:
LLM
,
llm
:
LLM
,
bad_words
:
Optional
[
list
[
str
]]
=
None
)
->
list
[
int
]:
return
_generate
(
model
=
model
,
llm
=
llm
,
prompt
=
self
.
PROMPT
,
num_prompt_tokens
=
self
.
num_prompt_tokens
,
bad_words
=
bad_words
,
...
...
tests/samplers/test_seeded_generate.py
View file @
711aa9d5
...
...
@@ -51,7 +51,7 @@ def test_random_sample_with_seed(
sampling_params_seed_2
=
copy
.
deepcopy
(
sampling_params
)
sampling_params_seed_2
.
seed
=
200
llm
=
vllm_model
.
model
llm
=
vllm_model
.
llm
for
prompt
in
example_prompts
:
for
params
in
(
...
...
tests/samplers/test_typical_acceptance_sampler.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for rejection sampling."""
import
pytest
import
torch
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.model_executor.utils
import
set_random_seed
# Device strings the tests below are parametrized over; range(1) yields only
# "cuda:0".
CUDA_DEVICES = [f"cuda:{i}" for i in range(1)]
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    This file tests V0 internals, so set VLLM_USE_V1=0.

    Declared as an autouse, function-scoped fixture, so the environment
    variable is set through ``monkeypatch`` for each test in this module.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')
def get_zero_temperature_prob_dist(batch_size, k, vocab_size):
    """
    Generates a fake temperature zero probability distribution.

    Args:
        batch_size: Number of sequences in the batch.
        k: Number of positions per sequence.
        vocab_size: Size of the vocabulary (last dimension).

    Returns:
        1. A fake temperature zero probability distribution of shape
           [batch_size, k, vocab_size]
        2. Tensor of shape [batch_size, k] containing the token ids
           of the probability 1.0 tokens at each position.
    """
    # Random scores are only used to pick, per position, which token id
    # becomes the single probability-1.0 token.
    probs = torch.rand(batch_size, k, vocab_size)
    _, zero_temperature_token_ids = torch.max(probs, dim=-1)
    # set the probability of the tokens with ids in zero_temperature_token_ids
    # to 1 and the rest to 0.
    target_probs = torch.zeros_like(probs).scatter_(
        -1, zero_temperature_token_ids.unsqueeze(-1), 1.0)
    return target_probs, zero_temperature_token_ids
def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
                        token_ids_to_exclude: torch.Tensor):
    """
    Build a [batch_size, k] tensor of fake draft token ids.

    Each entry is drawn at random from [0, vocab_size) and redrawn until it
    differs from the entry of token_ids_to_exclude at the same position.
    """
    sampled_ids = torch.empty(batch_size, k, dtype=torch.long)
    for row in range(batch_size):
        for col in range(k):
            forbidden = token_ids_to_exclude[row, col]
            # Rejection-sample: keep drawing until the candidate is not the
            # excluded id for this position.
            candidate = torch.randint(0, vocab_size, (1, )).item()
            while candidate == forbidden:
                candidate = torch.randint(0, vocab_size, (1, )).item()
            sampled_ids[row, col] = candidate
    return sampled_ids
def get_acceptance_sampler(
    posterior_threshold: float = 0.03,
    posterior_alpha: float = 0.9,
    strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
    """
    Initializes and returns a TypicalAcceptanceSampler.

    The three arguments are forwarded positionally, in the order
    (posterior_threshold, posterior_alpha, strict_mode), to the
    TypicalAcceptanceSampler constructor.
    """
    return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
                                    strict_mode)
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                    device: str):
    """
    Smoke test for the TypicalAcceptanceSampler forward pass.

    Runs the sampler once for each parametrized combination of k,
    vocab_size, batch_size and device; the test passes as long as the call
    does not raise.
    """
    torch.set_default_device(device)

    def random_token_ids(num_columns: int) -> torch.Tensor:
        # int64 ids drawn from [0, vocab_size), one row per sequence.
        return torch.randint(low=0,
                             high=vocab_size,
                             size=(batch_size, num_columns),
                             dtype=torch.int64)

    sampler = get_acceptance_sampler()
    sampler.init_gpu_tensors(device=device)
    target_with_bonus_probs = torch.rand(batch_size,
                                         k + 1,
                                         vocab_size,
                                         dtype=torch.float32)
    bonus_token_ids = random_token_ids(1)
    draft_token_ids = random_token_ids(k)
    # Verify that sampling succeeds for all cases.
    sampler(target_with_bonus_probs,
            bonus_token_ids,
            draft_probs=None,
            draft_token_ids=draft_token_ids)
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str, device: str):
    """
    Tests that we throw an exception if the token ids fall outside
    the bounds of the provided vocabulary.

    One entry of either the bonus or the draft token ids is overwritten with
    an id above or below the vocabulary range, and the sampler (created with
    strict_mode=True) is expected to raise AssertionError.
    """
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)

    def random_token_ids(num_columns: int) -> torch.Tensor:
        # int64 ids drawn from [0, vocab_size), one row per sequence.
        return torch.randint(low=0,
                             high=vocab_size,
                             size=(batch_size, num_columns),
                             dtype=torch.int64)

    sampler = get_acceptance_sampler(strict_mode=True)
    sampler.init_gpu_tensors(device=device)
    target_with_bonus_probs = torch.rand(batch_size,
                                         k + 1,
                                         vocab_size,
                                         dtype=torch.float32)
    bonus_token_ids = random_token_ids(1)
    draft_token_ids = random_token_ids(k)

    # Pick the tensor that will receive the out-of-bounds id.
    tensors_by_name = {
        "bonus_token_ids": bonus_token_ids,
        "draft_token_ids": draft_token_ids,
    }
    if which_token_ids not in tensors_by_name:
        raise AssertionError()
    oob_token_ids = tensors_by_name[which_token_ids]

    # Pick an id just outside the valid range on the requested side.
    rogue_ids_by_side = {"above": vocab_size + 1, "below": -1}
    if above_or_below_vocab_range not in rogue_ids_by_side:
        raise AssertionError()
    oob_token_ids[0][0] = rogue_ids_by_side[above_or_below_vocab_range]

    # Verify that appropriate exceptions are thrown for out
    # of bound vocabs.
    with pytest.raises(AssertionError):
        sampler(target_with_bonus_probs,
                bonus_token_ids,
                draft_probs=None,
                draft_token_ids=draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
        seed: int, device: str):
    """
    Test the TypicalAcceptanceSampler with a uniform target probability
    distribution.

    The target probabilities are filled with torch.rand values. The test
    expects every draft token to be accepted and the bonus token to be
    emitted in the last output column, on the reasoning that the high
    entropy of such a target distribution leads to acceptance.
    """
    set_random_seed(seed)
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)

    def random_token_ids(num_columns: int) -> torch.Tensor:
        # int64 ids drawn from [0, vocab_size), one row per sequence.
        return torch.randint(low=0,
                             high=vocab_size,
                             size=(batch_size, num_columns),
                             dtype=torch.int64)

    sampler = get_acceptance_sampler(strict_mode=True)
    sampler.init_gpu_tensors(device=device)
    target_with_bonus_probs = torch.rand(batch_size,
                                         k + 1,
                                         vocab_size,
                                         dtype=torch.float32)
    draft_token_ids = random_token_ids(k)
    bonus_token_ids = random_token_ids(1)
    sampled = sampler(target_with_bonus_probs,
                      bonus_token_ids,
                      draft_probs=None,
                      draft_token_ids=draft_token_ids)
    # We are using a uniform target probability distribution.
    # For a uniform distribution the entropy is very high and it
    # should lead to all draft tokens being accepted. Verify that.
    num_rows, num_cols = sampled.shape[0], sampled.shape[1]
    assert num_rows == batch_size
    assert num_cols == (k + 1)
    assert torch.all(sampled[:, -1] == bonus_token_ids.squeeze())
    assert torch.all(sampled[:, :k] == draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int, device: str):
    """
    Test the TypicalAcceptanceSampler with a zero-temperature target
    probability distribution.

    Only one token per position has probability 1.0 and the draft tokens are
    chosen to never match it. The test expects the first output column to
    hold the probability-1.0 token of position 0 (the greedy fallback) and
    the last output column to be -1.
    """
    set_random_seed(seed)
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)
    sampler = get_acceptance_sampler(strict_mode=True)
    sampler.init_gpu_tensors(device=device)
    # Simulate temperature 0 probability distribution for target probabilities
    # and create target probabilities such that only 1 token id has
    # probability 1.0
    target_with_bonus_probs, all_hot_token_ids = \
        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
    # Drop the bonus position; only the k draft positions are compared.
    zero_temperature_token_ids = all_hot_token_ids[:, :-1]
    # Populate draft_token_ids such that they exclude the token_ids
    # with probability = 1.0
    draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                          zero_temperature_token_ids)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    # The target probability distribution is a temperature zero distribution
    # with zero entropy. Since our draft token ids don't match the probability
    # 1.0 tokens in the target distribution we will reject all of them and
    # fallback to the greedy sampling for selecting 1 token for each sequence.
    # Verify the same.
    sampled = sampler(target_with_bonus_probs,
                      bonus_token_ids,
                      draft_probs=None,
                      draft_token_ids=draft_token_ids)
    assert sampled.shape[0] == batch_size
    assert sampled.shape[1] == (k + 1)
    assert torch.all(sampled[:, -1] == -1)
    assert torch.all(sampled[:, 0] == zero_temperature_token_ids[:, 0])
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, device: str):
    """
    Test the TypicalAcceptanceSampler with a mixed target probability
    distribution.

    Sequences 0 and 2 keep a zero-temperature distribution, while sequences
    1 and 3 have their draft positions overwritten with torch.rand values.
    The test verifies that:
    - For the zero-temperature sequences only the probability-1.0 token at
      position 0 is emitted and every later column is -1.
    - For the other sequences all draft tokens are accepted and the last
      column is not -1.
    """
    set_random_seed(seed)
    k = 3
    batch_size = 4
    vocab_size = 30_000
    zero_temperature_rows = [0, 2]
    uniform_rows = [1, 3]
    torch.set_default_device(device)
    sampler = get_acceptance_sampler(strict_mode=True)
    sampler.init_gpu_tensors(device=device)
    # For sequences 0 and 2 set the distribution to a temperature
    # zero distribution. For sequences 1 and 3 set it to a uniform
    # distribution.
    target_with_bonus_probs, all_hot_token_ids = \
        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
    zero_temperature_token_ids = all_hot_token_ids[:, :-1]
    # Slice of target_with_bonus_probs without the bonus position; the
    # indexed assignment below writes through to the full tensor.
    target_probs = target_with_bonus_probs[:, :-1]
    draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                          zero_temperature_token_ids)
    uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
    target_probs[uniform_rows] = uniform_probs
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    sampled = sampler(target_with_bonus_probs,
                      bonus_token_ids,
                      draft_probs=None,
                      draft_token_ids=draft_token_ids)
    # verify the shape of output_token_ids
    assert sampled.shape[0] == batch_size
    assert sampled.shape[1] == (k + 1)
    # For sequences 0 and 2 verify that only 1 token is accepted
    # which is the token with probability 1.0 in the target distribution
    # at position 0.
    assert torch.all(sampled[zero_temperature_rows, 1:] == -1)
    assert (torch.all(sampled[zero_temperature_rows, 0] ==
                      zero_temperature_token_ids[zero_temperature_rows, 0]))
    # For sequences 1 and 3 verify that all tokens are accepted since the
    # target probability distribution is uniform. In addition verify that
    # we also accept the bonus tokens.
    assert torch.all(
        sampled[uniform_rows, :-1] == draft_token_ids[uniform_rows, :])
    assert torch.all(sampled[uniform_rows, -1] != -1)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, device: str):
    """
    Test the TypicalAcceptanceSampler's behavior when only a subset of draft
    tokens should be accepted.

    Against a zero-temperature target distribution the test checks two cases:
    - All draft tokens equal the probability-1.0 tokens: every draft token
      and the bonus token are emitted.
    - Only the first 2 draft tokens equal the probability-1.0 tokens: those 2
      are emitted, followed by the argmax token of position 2, and the last 3
      columns are -1.
    """
    set_random_seed(seed)
    k = 5
    batch_size = 1
    vocab_size = 30_000
    torch.set_default_device(device)
    sampler = get_acceptance_sampler(strict_mode=True)
    sampler.init_gpu_tensors(device=device)
    # Create a temperature zero target probability distribution and ensure
    # all draft token ids correspond to the tokens with 1.0 probability.
    # Verify that all of them are accepted.
    target_with_bonus_probs, all_hot_token_ids = \
        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
    zero_temperature_token_ids = all_hot_token_ids[:, :-1]
    matching_draft_token_ids = zero_temperature_token_ids
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    sampled = sampler(target_with_bonus_probs,
                      bonus_token_ids,
                      draft_probs=None,
                      draft_token_ids=matching_draft_token_ids)
    assert sampled.shape[0] == batch_size
    assert sampled.shape[1] == (k + 1)
    assert torch.all(sampled[:, 0:-1] == matching_draft_token_ids)
    assert torch.all(sampled[:, -1] == bonus_token_ids)
    # Next only keep the first 2 draft tokens same as the zero temperature
    # tokens. For the remaining 3 choose some other tokens. In the
    # response we will expect the first 2 tokens to be the same as the
    # draft tokens and the recovered token and rest as -1
    mismatching_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                                zero_temperature_token_ids)
    partial_draft_token_ids = torch.cat(
        (matching_draft_token_ids[:, :2], mismatching_token_ids[:, -3:]),
        dim=1)
    sampled = sampler(target_with_bonus_probs,
                      bonus_token_ids,
                      draft_probs=None,
                      draft_token_ids=partial_draft_token_ids)
    assert sampled.shape[0] == batch_size
    assert sampled.shape[1] == (k + 1)
    assert torch.all(sampled[:, :2] == partial_draft_token_ids[:, :2])
    assert torch.all(
        sampled[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
    assert torch.all(sampled[:, -3:] == -1)
@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
    """
    Test the TypicalAcceptanceSampler with custom posterior thresholds and
    alpha values.

    The same inputs are sampled twice: once with the default posterior
    settings, where the mismatching draft tokens are expected to be rejected,
    and once with posterior_threshold=0.0 and posterior_alpha=0.0, where all
    draft tokens and the bonus token are expected to be emitted.
    """
    set_random_seed(seed)
    k = 5
    batch_size = 1
    vocab_size = 30_000
    torch.set_default_device(device)
    default_sampler = get_acceptance_sampler(strict_mode=True)
    default_sampler.init_gpu_tensors(device=device)
    # Simulate temperature 0 probability distribution for target
    # probabilities and create target probabilities such that only 1 token
    # id has probability 1.0 and others have a very low probability of
    # 0.00001. Populate draft_token_ids such that they exclude the token_ids
    # with probability = 1.0. Without any changes to the posterior thresholds
    # none of the draft tokens are accepted.
    target_probs, all_hot_token_ids = get_zero_temperature_prob_dist(
        batch_size, k + 1, vocab_size)
    zero_temperature_token_ids = all_hot_token_ids[:, :-1]
    target_probs[target_probs == 0] = 0.00001
    draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                          zero_temperature_token_ids)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    sampled = default_sampler(target_probs,
                              bonus_token_ids,
                              draft_probs=None,
                              draft_token_ids=draft_token_ids)
    assert sampled.shape[0] == batch_size
    assert sampled.shape[1] == (k + 1)
    assert torch.all(sampled[:, 1:-1] == -1)

    # Change the posterior threshold values to 0.0 so that we will
    # now accept even draft tokens with very low probability in the
    # target distribution. Simulate and verify the same.
    permissive_sampler = TypicalAcceptanceSampler(strict_mode=True,
                                                  posterior_threshold=0.0,
                                                  posterior_alpha=0.0)
    permissive_sampler.init_gpu_tensors(device=device)
    sampled = permissive_sampler(target_probs,
                                 bonus_token_ids,
                                 draft_probs=None,
                                 draft_token_ids=draft_token_ids)
    assert sampled.shape[0] == batch_size
    assert sampled.shape[1] == (k + 1)
    assert torch.all(sampled[:, 0:-1] == draft_token_ids)
    assert torch.all(sampled[:, -1] == bonus_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_get_recovered_token_ids(seed: int, device: str):
    """
    Test the TypicalAcceptanceSampler's method for generating
    replacement token IDs.

    Checks that `_get_recovered_token_ids` returns, for random target
    probabilities, the same ids as an argmax over the vocabulary dimension.
    """
    set_random_seed(seed)
    k = 10
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)
    sampler = get_acceptance_sampler(strict_mode=True)
    sampler.init_gpu_tensors(device=device)
    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    # Reference answer: highest-probability token at every position.
    expected_ids = torch.argmax(target_probs, dim=-1)
    recovered_ids = sampler._get_recovered_token_ids(target_probs)
    assert torch.all(expected_ids == recovered_ids)
tests/spec_decode/conftest.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    Since this module is V0 only, set VLLM_USE_V1=0 for
    all tests in the module.

    Declared as an autouse, function-scoped fixture; the variable is set
    through ``monkeypatch``.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')
tests/spec_decode/e2e/conftest.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
itertools
import
cycle
from
typing
import
Optional
,
Union
import
pytest
import
torch
from
vllm
import
LLM
,
SamplingParams
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
PromptLogprobs
,
SampleLogprobs
from
...models.utils
import
(
TokensTextLogprobs
,
TokensTextLogprobsPromptLogprobs
,
check_logprobs_close
,
check_outputs_equal
)
from
...utils
import
RemoteOpenAIServer
# Prompt pool; the helpers below cycle through it to build a batch of the
# requested size. The strings are model inputs and are kept verbatim.
PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
    "San Francisco is know for its",
    "Facebook was created in 2004 by",
    "Curious George is a",
    "Python 3.11 brings improvements to its",
]
@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """
    Fixture returning a generator function that yields one LLM instance.

    The LLM is built from the three kwargs mappings merged in order (later
    mappings override earlier ones). If seed is not None the random seed is
    set after construction. Once the consumer resumes the generator after
    the yield, the LLM reference is dropped and
    cleanup_dist_env_and_memory() is called.
    """

    def generate():
        # Merge in precedence order: common < per-test common < test-specific.
        merged_kwargs = dict(common_llm_kwargs)
        merged_kwargs.update(per_test_common_llm_kwargs)
        merged_kwargs.update(test_llm_kwargs)

        llm = LLM(**merged_kwargs)

        if seed is not None:
            set_random_seed(seed)

        yield llm

        # Runs only when the consumer advances past the yielded LLM.
        del llm
        cleanup_dist_env_and_memory()

    return generate
def maybe_assert_ngram_worker(llm):
    """
    Assert that the proposer worker is an NGramWorker when ngram is used.

    Does nothing when the engine has no speculative config or when the
    config's method is not "ngram".
    """
    spec_config = llm.llm_engine.speculative_config
    if spec_config is None or spec_config.method != "ngram":
        return

    # Verify the proposer worker is ngram if ngram is specified.
    from vllm.spec_decode.ngram_worker import NGramWorker
    proposer = llm.llm_engine.model_executor.driver_worker.proposer_worker
    assert isinstance(proposer, NGramWorker)
def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> tuple[list[str], list[list[int]], float]:
    """
    Generate with every LLM produced by llm_generator and collect the output.

    Returns:
        A tuple (tokens, token_ids, acceptance_rate) taken from the last LLM
        yielded: the text and token ids of the first completion of each
        request, and the draft acceptance rate read from the "prometheus"
        stat logger. acceptance_rate stays -1.0 when the engine exposes no
        (or empty) stat_loggers.
    """
    tokens: list[str] = []
    token_ids: list[list[int]] = []
    acceptance_rate: float = -1.0
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        request_outputs = llm.generate(prompts,
                                       sampling_params,
                                       use_tqdm=True)

        # Only the first completion of each request is inspected.
        first_completions = [
            request_output.outputs[0] for request_output in request_outputs
        ]
        token_ids = [completion.token_ids for completion in first_completions]
        tokens = [completion.text for completion in first_completions]

        # Fetch acceptance rate if logging is enabled.
        stat_loggers = getattr(llm.llm_engine, "stat_loggers", None)
        if stat_loggers:
            stat_logger = stat_loggers["prometheus"]
            rate_gauge = (
                stat_logger.metrics.gauge_spec_decode_draft_acceptance_rate)
            acceptance_rate = rate_gauge.labels(
                **stat_logger.labels)._value.get()
        # Drop the local reference before the generator is advanced again.
        del llm

    return tokens, token_ids, acceptance_rate
def check_logprobs_correctness(
    spec_outputs: Sequence[Union[TokensTextLogprobs,
                                 TokensTextLogprobsPromptLogprobs]],
    baseline_outputs: Sequence[Union[TokensTextLogprobs,
                                     TokensTextLogprobsPromptLogprobs]],
    disable_logprobs: bool = False,
):
    """Compare sampled and prompt logprobs between baseline and spec decoding

    When disable_logprobs is False the comparison is delegated to
    check_logprobs_close and its result is returned. Otherwise every output
    pair is checked with _check_logprobs_when_output_disabled: element 2
    (sample logprobs) always, and element 3 (prompt logprobs) when the
    baseline output has 4 elements.
    """
    if disable_logprobs:
        # Check correctness when disable_logprobs == True
        for spec_output, baseline_output in zip(spec_outputs,
                                                baseline_outputs):
            # Check generated token logprobs.
            _check_logprobs_when_output_disabled(spec_output[2],
                                                 baseline_output[2],
                                                 is_prompt_logprobs=False)

            # Check prompt logprobs too, if they exist
            if len(baseline_output) == 4:
                assert len(spec_output) == 4
                _check_logprobs_when_output_disabled(spec_output[3],
                                                     baseline_output[3],
                                                     is_prompt_logprobs=True)
        return None

    return check_logprobs_close(
        outputs_0_lst=baseline_outputs,
        outputs_1_lst=spec_outputs,
        name_0="org",
        name_1="sd",
    )
def _check_logprobs_when_output_disabled(
    spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
    baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
    is_prompt_logprobs: bool = False,
):
    """
    Check the placeholder logprobs emitted when logprob output is disabled.

    Each position of spec_logprobs must hold exactly one entry with rank -1
    and logprob 0.0, and that entry's token id must be a key of the baseline
    logprobs at the same position. For prompt logprobs, a baseline of None
    (whole value, or position 0) must be mirrored by None on the spec side.
    """
    # Prompt logprobs are optional
    if is_prompt_logprobs and baseline_logprobs is None:
        assert spec_logprobs is None
        return

    assert spec_logprobs is not None
    assert baseline_logprobs is not None
    assert len(spec_logprobs) == len(baseline_logprobs)

    # For each generated position of the sequence.
    paired_positions = zip(spec_logprobs, baseline_logprobs)
    for pos, (spec_entry, baseline_entry) in enumerate(paired_positions):

        # First prompt logprob is expected to be None
        if is_prompt_logprobs and baseline_entry is None:
            assert spec_entry is None
            assert pos == 0
            continue

        assert spec_entry is not None
        assert baseline_entry is not None

        # When disabled, the 1 logprob is returned with dummy values for the
        # score and rank, but the token id should match the baseline model
        assert len(spec_entry) == 1
        ((token_id, placeholder), ) = spec_entry.items()
        assert placeholder.rank == -1
        assert placeholder.logprob == 0.0
        # Tensor keys are converted to a plain number before the lookup.
        if isinstance(token_id, torch.Tensor):
            token_id = token_id.item()
        assert token_id in baseline_entry
def run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size: int,
        max_output_len: int,
        seed: Optional[int] = 0,
        temperature: float = 0.0,
        disable_seed: bool = False,
        ignore_eos: bool = True,
        ensure_all_accepted: bool = False,
        expected_acceptance_rate: Optional[float] = None,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        disable_logprobs: bool = False):
    """
    Run the same prompts through a baseline and a test configuration and
    assert that the generated outputs are equal.

    Both runs use vllm_runner with the common kwargs merged with either
    baseline_llm_kwargs ("org") or test_llm_kwargs ("sd"). Only the first two
    elements of each output are compared for equality. Logprobs are checked
    via check_logprobs_correctness when logprobs or prompt_logprobs is set.
    When expected_acceptance_rate is given, the acceptance rate read from the
    "prometheus" stat logger must be at least that value minus 1e-2.
    disable_seed=True overrides seed with None.
    """
    shared_kwargs = {**common_llm_kwargs, **per_test_common_llm_kwargs}
    org_args = {**shared_kwargs, **baseline_llm_kwargs}
    sd_args = {**shared_kwargs, **test_llm_kwargs}

    # Repeat the prompt pool until batch_size prompts are collected.
    prompts = [
        prompt for _, prompt in zip(range(batch_size), cycle(PROMPTS))
    ]

    if disable_seed:
        seed = None

    sampling_params = SamplingParams(temperature=temperature,
                                     max_tokens=max_output_len,
                                     seed=seed,
                                     ignore_eos=ignore_eos,
                                     logprobs=logprobs,
                                     prompt_logprobs=prompt_logprobs)

    with vllm_runner(**org_args) as vllm_model:
        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    with vllm_runner(**sd_args) as vllm_model:
        track_acceptance = (ensure_all_accepted
                            or expected_acceptance_rate is not None)
        if track_acceptance:
            # Force log interval to be 0 to catch all metrics.
            stat_logger = vllm_model.model.llm_engine.stat_loggers[
                'prometheus']
            stat_logger.local_interval = -100

        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

        if track_acceptance:
            rate_gauge = (
                stat_logger.metrics.gauge_spec_decode_draft_acceptance_rate)
            acceptance_rate = rate_gauge.labels(
                **stat_logger.labels)._value.get()

            if ensure_all_accepted:
                assert True
                # FIXME: ci fails to log acceptance rate.
                # It works locally.
                # assert acceptance_rate == 1.0

            if expected_acceptance_rate is not None:
                assert acceptance_rate >= expected_acceptance_rate - 1e-2

    # Only pass token entries, not the logprobs
    org_token_entries = [out[0:2] for out in org_outputs]
    sd_token_entries = [out[0:2] for out in sd_outputs]
    check_outputs_equal(outputs_0_lst=org_token_entries,
                        outputs_1_lst=sd_token_entries,
                        name_0="org",
                        name_1="sd")

    # Check logprobs if requested
    if logprobs is not None or prompt_logprobs is not None:
        check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
def run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size: int,
                                     max_output_len: int,
                                     seed: int = 0,
                                     temperature: float = 0.0,
                                     logprobs: Optional[int] = None):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
    the same when temperature is zero.

    A RemoteOpenAIServer is started once with the baseline arguments and once
    with the test arguments; the completions' text, finish reasons and usage
    must be equal. Logprobs are compared separately, on their tokens only,
    and only when logprobs is truthy.
    """
    shared_args = common_llm_kwargs + per_test_common_llm_kwargs
    arg1 = shared_args + baseline_llm_kwargs
    arg2 = shared_args + test_llm_kwargs
    env1 = env2 = None

    max_wait_seconds = 240
    results = []

    # Repeat the prompt pool until batch_size prompts are collected.
    prompts = [
        prompt for _, prompt in zip(range(batch_size), cycle(PROMPTS))
    ]

    for args, env in ((arg1, env1), (arg2, env2)):
        with RemoteOpenAIServer(model,
                                args,
                                env_dict=env,
                                max_wait_seconds=max_wait_seconds) as server:
            client = server.get_client()

            completion = client.completions.create(model=model,
                                                   prompt=prompts,
                                                   max_tokens=max_output_len,
                                                   seed=seed,
                                                   temperature=temperature,
                                                   logprobs=logprobs)

            choices = completion.choices
            results.append({
                "test": "seeded_sampling",
                "text": [choice.text for choice in choices],
                "logprobs": [choice.logprobs for choice in choices],
                "finish_reason": [choice.finish_reason for choice in choices],
                "usage": completion.usage,
            })

    # First half of the results belongs to arg1, second half to arg2.
    n = len(results) // 2
    arg1_results, arg2_results = results[:n], results[n:]

    # Separate logprobs to avoid asserting exact equality.
    arg1_logprobs = [entry.pop("logprobs") for entry in arg1_results]
    arg2_logprobs = [entry.pop("logprobs") for entry in arg2_results]

    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
        assert arg1_result == arg2_result, (
            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
            f"{arg1_result=} != {arg2_result=}")
    if logprobs:
        for baseline_logs, test_logs in zip(arg1_logprobs, arg2_logprobs):
            for baseline_log, test_log in zip(baseline_logs, test_logs):
                assert baseline_log.tokens == test_log.tokens
tests/spec_decode/test_batch_expansion.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
.utils
import
create_seq_group_metadata_from_prompts
,
mock_worker
@pytest.mark.parametrize('num_target_seq_ids', [100])
@pytest.mark.skip_global_cleanup
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
    """Check that every generated target seq id exceeds all input seq ids."""
    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)

    input_id_cases = [
        [1, 3, 5, 7],
        list(range(100)) + [0],
        [100],
    ]

    for input_ids in input_id_cases:
        largest_input_id = max(input_ids)
        # pylint: disable=protected-access
        fresh_ids = scorer._create_target_seq_id_iterator(input_ids)
        # Draw the requested number of ids; each must be strictly larger
        # than anything in the input.
        for _ in range(num_target_seq_ids):
            assert next(fresh_ids) > largest_input_id
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_get_token_ids_to_score(k: int):
    """Check that the scorer returns every prefix of the proposal tokens."""
    proposal_token_ids = torch.tensor(
        list(range(k)),
        dtype=torch.int64,
        device='cuda',
    )

    # Expected: the empty prefix followed by prefixes of length 1..k.
    num_proposed = proposal_token_ids.shape[0]
    expected_output: list[list[int]] = [
        proposal_token_ids[:end].tolist() for end in range(num_proposed + 1)
    ]

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    # pylint: disable=protected-access
    raw_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist())

    # Normalise tensor entries to plain lists before comparing.
    actual_output = []
    for entry in raw_output:
        if isinstance(entry, torch.Tensor):
            actual_output.append(entry.tolist())
        else:
            actual_output.append(entry)

    assert actual_output == expected_output
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_create_single_target_seq_group_metadata(k: int):
    """Check the batch-expanded seq group metadata built for one target."""
    prompt_tokens = [1, 2, 3]
    prev_output_tokens = [4, 5, 6]
    token_ids = list(range(k))

    context_len = len(prompt_tokens) + len(prev_output_tokens)
    num_tokens_processed = context_len - 1
    final_seq_len = context_len + len(token_ids)

    block_size = 32
    source_metadata = create_seq_group_metadata_from_prompts(
        [prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
        [prev_output_tokens], [num_tokens_processed])[0]

    input_seq_id = list(source_metadata.seq_data.keys())[0]
    target_seq_id = 100

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    output = scorer._create_single_target_seq_group_metadata(  # pylint: disable=protected-access
        source_metadata,
        input_seq_id,
        target_seq_id,
        token_ids,
        source_metadata.sampling_params,
    )

    assert output.request_id == source_metadata.request_id

    # The sampling parameters must be carried over unchanged.
    for field in ("repetition_penalty", "temperature", "top_p", "top_k"):
        assert getattr(output.sampling_params, field) == \
            getattr(source_metadata.sampling_params, field)

    assert len(output.seq_data) == 1
    target_data = output.seq_data[target_seq_id]
    assert target_data.get_prompt_token_ids() == tuple(prompt_tokens)
    assert target_data.get_output_token_ids() == tuple(prev_output_tokens +
                                                       token_ids)

    assert len(output.block_tables) == 1
    assert output.block_tables[target_seq_id] == \
        source_metadata.block_tables[input_seq_id]
tests/spec_decode/test_dynamic_spec_decode.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest.mock
import
MagicMock
,
patch
import
pytest
import
torch
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
.test_utils
import
mock_spec_decode_sampler
from
.utils
import
create_batch
,
mock_worker
@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
                             acceptance_sampler_method: str):
    """Check that speculation is switched off once the running queue size
    exceeds ``disable_by_batch_size``.
    """
    disable_by_batch_size = 3
    exception_secret = 'artificial stop'

    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(proposer_worker=draft_worker,
                              scorer_worker=target_worker,
                              spec_decode_sampler=mock_spec_decode_sampler(
                                  acceptance_sampler_method),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector,
                              disable_by_batch_size=disable_by_batch_size)

    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        running_queue_size=queue_size)

    if queue_size > disable_by_batch_size:
        # The patched no-spec path raises, which stops execution right after
        # the worker has decided to disable speculation.
        with patch.object(worker,
                          '_run_no_spec',
                          side_effect=ValueError(exception_secret)):
            with pytest.raises(ValueError, match=exception_secret):
                worker.execute_model(execute_model_req=execute_model_req)

    # When the batch size is larger than the threshold,
    # we expect no speculative tokens (0).
    if queue_size < disable_by_batch_size:
        expected_num_spec_tokens = None
    else:
        expected_num_spec_tokens = 0
    assert seq_group_metadata_list[0].num_speculative_tokens == \
        expected_num_spec_tokens

    draft_worker.sampler_output.side_effect = ValueError(exception_secret)

    proposer = Top1Proposer(
        worker=draft_worker,
        device='cpu',  # not used
        vocab_size=100,  # not used
        # Must be long enough to avoid being skipped due to length.
        max_proposal_len=1024,
    )

    def request_proposals():
        # Build a fresh request on every call, as each branch below does.
        return proposer.get_spec_proposals(
            execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                num_lookahead_slots=k),
            seq_ids_with_bonus_token_in_last_step=set())

    if queue_size >= disable_by_batch_size:
        # Should not execute the draft model because spec decode is disabled
        # for all requests. Accordingly, the proposal length should be 0.
        proposals = request_proposals()
        assert proposals.proposal_lens.tolist() == [0] * batch_size
    else:
        # Should raise exception when executing the mocked draft model.
        with pytest.raises(ValueError, match=exception_secret):
            request_proposals()
tests/spec_decode/test_memory_usage.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
This test verifies that memory usage remains constant (or never grows) when
we enable / disable speculation via --speculative-disable-by-batch-size.
There are a lot of things we try to keep track of between batches of requests
and if certain tensors are not freed from memory, can result in CUDA ooms.
This is particularly relevant for production situations where speculation might
be enabled during off hours, but disabled once traffic peaks during the workday.
Since traffic will stay high for a long period of time, verifying we do not
increase our memory usage over time is essential to prevent possible CUDA ooms.
"""
import
torch
import
vllm
from
tests.core.utils
import
create_dummy_prompt
from
vllm.sequence
import
SequenceGroup
# Number of loop iterations (one request is added per iteration) in the test.
ITERATIONS = 100

# Passed as `model=` when constructing the LLM.
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"

# The test only checks memory once more than this many sequences are in flight.
BATCH_SIZE = 5

# Passed as `disable_by_batch_size` in the speculative config.
SPEC_DISABLE_BATCH_SIZE = 2
def add_seq_group_to_engine(engine: vllm.LLMEngine, seq_group: SequenceGroup):
    """Hand ``seq_group`` to the engine's first scheduler."""
    engine.scheduler[0].add_seq_group(seq_group)
"""
Since we are using a batch size greater than the disabled batch size,
we can ensure we go through the _no_spec codepath for most of our engine steps.
"""
def test_memory_usage_no_spec():
    """Check that allocated CUDA memory stays the same across batches of
    requests while more than ``BATCH_SIZE`` sequences are in flight.
    """
    baseline_memory = None
    llm = vllm.LLM(model=MAIN_MODEL,
                   speculative_config={
                       "model": SPEC_MODEL,
                       "num_speculative_tokens": 3,
                       "disable_by_batch_size": SPEC_DISABLE_BATCH_SIZE,
                   })
    engine = llm.llm_engine
    in_flight = set()

    def drop_finished():
        # Forget every sequence that has completed.
        for tracked in list(in_flight):
            if tracked.is_finished():
                in_flight.remove(tracked)

    for i in range(ITERATIONS):
        seq, seq_group = create_dummy_prompt(request_id=str(i),
                                             prompt_length=10,
                                             min_tokens=10,
                                             max_tokens=10)
        add_seq_group_to_engine(engine, seq_group)
        in_flight.add(seq)

        engine.step()
        drop_finished()

        # If we aren't at our batch size yet, continue
        if len(in_flight) <= BATCH_SIZE:
            continue

        # Otherwise, loop until at least one request is done
        while not any(tracked.is_finished() for tracked in in_flight):
            engine.step()

        # Remove it from the set
        drop_finished()

        # At this point some requests of the batch have finished after
        # several _no_spec executions. The first time through we record the
        # allocated memory; every later time it must be unchanged.
        current_memory = torch.cuda.memory_allocated()
        if baseline_memory is None:
            baseline_memory = current_memory
        else:
            assert baseline_memory == current_memory
tests/spec_decode/test_utils.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest.mock
import
MagicMock
import
pytest
import
torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.sampler
import
_get_ranks
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.sequence
import
SequenceGroupMetadata
,
get_all_seq_ids
from
vllm.spec_decode.util
import
(
get_sampled_token_logprobs
,
split_batch_by_proposal_len
)
def test_get_all_seq_ids():
    """Check that get_all_seq_ids returns every seq id, in input order."""
    expected_seq_ids = [*range(10), *range(100, 110)]

    # One single-sequence group per id; payloads are irrelevant, so mocks do.
    seq_group_metadata_list = []
    for seq_id in expected_seq_ids:
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=str(seq_id),
                is_prompt=True,
                seq_data={seq_id: MagicMock()},
                sampling_params=MagicMock(),
                block_tables={seq_id: MagicMock()},
                lora_request=None,
            ))

    assert get_all_seq_ids(seq_group_metadata_list) == expected_seq_ids
@pytest.fixture
def fake_sequence_group_metadata():
    """Provide three single-sequence groups with seq ids 0, 1 and 2."""
    groups = []
    for seq_id in range(3):
        groups.append(
            SequenceGroupMetadata(
                request_id=str(seq_id),
                is_prompt=True,
                seq_data={seq_id: MagicMock()},
                sampling_params=MagicMock(),
                block_tables={seq_id: MagicMock()},
                lora_request=None,
            ))
    return groups
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
    """The second returned pair must hold the zero-length-proposal groups."""
    groups = fake_sequence_group_metadata
    split = split_batch_by_proposal_len(groups, [0, 1, 0])
    filtered_groups, indices = split[1]

    assert filtered_groups == [groups[0], groups[2]]
    assert indices == [0, 2]
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
    """The first returned pair must hold the non-zero-proposal groups."""
    groups = fake_sequence_group_metadata
    split = split_batch_by_proposal_len(groups, [0, 1, 2])
    filtered_groups, indices = split[0]

    assert filtered_groups == [groups[1], groups[2]]
    assert indices == [1, 2]
def test_empty_inputs():
    """Empty inputs must give an empty second (zero-length) pair."""
    split = split_batch_by_proposal_len([], [])
    filtered_groups, indices = split[1]

    assert filtered_groups == []
    assert indices == []
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
    """With only zero-length proposals the first pair must be empty."""
    split = split_batch_by_proposal_len(fake_sequence_group_metadata,
                                        [0, 0, 0])
    filtered_groups, indices = split[0]

    assert filtered_groups == []
    assert indices == []
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
    """With only non-zero proposals the second pair must be empty."""
    split = split_batch_by_proposal_len(fake_sequence_group_metadata,
                                        [1, 1, 1])
    filtered_groups, indices = split[1]

    assert filtered_groups == []
    assert indices == []
def mock_spec_decode_sampler(acceptance_sampler_method):
    """Build a mocked acceptance sampler.

    Returns a ``MagicMock`` specced on ``RejectionSampler`` for
    'rejection_sampler' or on ``TypicalAcceptanceSampler`` for
    'typical_acceptance_sampler', with ``token_id_dtype`` set to
    ``torch.int64``.

    Raises:
        ValueError: For any other sampler name.
    """
    known_methods = ("rejection_sampler", "typical_acceptance_sampler")
    if acceptance_sampler_method not in known_methods:
        raise ValueError(
            f"Invalid sampler name {acceptance_sampler_method}")

    # Resolve the spec class only after validation, one name per branch.
    if acceptance_sampler_method == "rejection_sampler":
        sampler_cls = RejectionSampler
    else:
        sampler_cls = TypicalAcceptanceSampler

    sampler = MagicMock(spec=sampler_cls)
    sampler.token_id_dtype = torch.int64
    return sampler
def test_get_sampled_token_logprobs():
    """Check that get_sampled_token_logprobs ranks tokens the same way as
    _get_ranks when the probabilities are exactly tied.
    """
    # shape (num_steps, batch_size, vocab_size)
    logprob_tensor = torch.tensor([[[-0.1, -0.1], [-0.1, -0.1]]])
    # shape (num_steps, batch_size)
    sampled_token_tensor = torch.tensor([[1, 0]])

    ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
                                                   sampled_token_tensor)

    flat_logprobs = logprob_tensor.reshape((2, -1))
    flat_tokens = sampled_token_tensor.reshape(-1)
    ranks_regular = _get_ranks(flat_logprobs, flat_tokens)

    assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)
tests/spec_decode/utils.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
as
GenericSequence
from
itertools
import
count
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
from
unittest.mock
import
MagicMock
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.worker
import
Worker
# Type variable bounded by Worker; used below as the return type of
# create_worker so the result is typed as whatever `cls` constructs.
T = TypeVar("T", bound=Worker)
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return ``seq_len / block_size`` rounded up, i.e. the block count."""
    # Adding block_size - 1 before floor division yields ceiling division.
    padded_len = seq_len + block_size - 1
    return padded_len // block_size
def mock_worker(cls=None,
                vocab_size: int = 30_000,
                max_model_len: int = 2048,
                rank: int = 0,
                use_spec: bool = True) -> MagicMock:
    """Create a ``MagicMock`` standing in for a worker.

    Args:
        cls: Class used as the mock's spec; ``Worker`` when ``None``.
        vocab_size: Value of the mock's ``vocab_size`` attribute.
        max_model_len: Value of the mock's ``max_model_len`` attribute.
        rank: Value of the mock's ``rank`` attribute.
        use_spec: When false, the mock is created without a spec.

    Returns:
        The configured mock, with ``device`` set to ``'cuda:0'``.
    """
    worker_cls = Worker if cls is None else cls
    worker = MagicMock(spec=worker_cls if use_spec else None)

    worker.vocab_size = vocab_size
    worker.max_model_len = max_model_len
    worker.rank = rank
    worker.device = 'cuda:0'
    return worker
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]):
    """Return a wrapper around ``worker.execute_model`` that reseeds the RNG.

    After each wrapped call the next value of ``rand_seeds`` is passed to
    ``set_random_seed``. The worker itself is not modified; the caller
    decides where to install the returned function.
    """
    pending_seeds = iter(rand_seeds)
    wrapped_execute_model = worker.execute_model

    def new_execute_model(*args, **kwargs):
        outcome = wrapped_execute_model(*args, **kwargs)
        # Reseed only after the real call has completed.
        set_random_seed(next(pending_seeds))
        return outcome

    return new_execute_model
def zero_kv_cache(cache_engine: list[CacheEngine]):
    """Zero in place every key/value block of the first cache engine."""
    first_engine = cache_engine[0]
    assert first_engine.gpu_cache
    for keys, values in first_engine.gpu_cache:
        keys.zero_()
        values.zero_()
def create_worker(cls: Callable[..., T],
                  model_name: str,
                  block_size: int,
                  num_gpu_blocks: int,
                  seed: int,
                  is_driver_worker: bool = True,
                  enforce_eager: bool = True,
                  model_runner_cls: Optional[ModelRunner] = None,
                  dtype: Optional[str] = "auto") -> T:
    """Construct a worker of type ``cls``, load its model and set up its cache.

    The engine config is derived from ``EngineArgs``. The worker is created
    as rank 0 / local rank 0, its device is initialised and the model loaded.
    The cache is then initialised with ``num_gpu_blocks`` GPU blocks and no
    CPU blocks.

    Returns:
        The initialised worker.
    """
    engine_config = EngineArgs(
        model=model_name,
        seed=seed,
        block_size=block_size,
        enforce_eager=enforce_eager,
        dtype=dtype,
    ).create_engine_config()

    init_method = get_distributed_init_method(get_ip(), get_open_port())

    worker = cls(
        vllm_config=engine_config,
        local_rank=0,
        rank=0,
        distributed_init_method=init_method,
        is_driver_worker=is_driver_worker,
        model_runner_cls=model_runner_cls,
    )

    worker.init_device()
    worker.load_model()

    # Record the block counts on the config, then size the cache from it.
    cache_config = engine_config.cache_config
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0
    worker.initialize_cache(num_gpu_blocks=cache_config.num_gpu_blocks,
                            num_cpu_blocks=cache_config.num_cpu_blocks)

    return worker
def create_seq_group_metadata_from_prompts(
    prompts: list[list[int]],
    num_gpu_blocks: int,
    block_size: int,
    final_prompt_lens: list[int],
    continuations: Optional[list[list[int]]] = None,
    seq_ids: Optional[list[int]] = None,
) -> list[SequenceGroupMetadata]:
    """Build one single-sequence ``SequenceGroupMetadata`` per prompt.

    Args:
        prompts: Prompt token ids, one list per sequence.
        num_gpu_blocks: Size of the pool that block ids are drawn from.
        block_size: Tokens per block, used to size each allocation.
        final_prompt_lens: Final length per sequence; determines how many
            blocks each sequence is given.
        continuations: Already generated token ids per sequence; defaults to
            an empty continuation for every prompt.
        seq_ids: Defaulted to the prompt indices when ``None``.

    Returns:
        The metadata list, in prompt order.
    """
    if continuations is None:
        continuations = [[] for _ in prompts]

    # NOTE(review): seq_ids is defaulted here but not read again below;
    # sequences and block tables are keyed by the prompt index. Confirm
    # whether callers expect their seq_ids to be honoured.
    if seq_ids is None:
        seq_ids = [index for index, _ in enumerate(prompts)]

    # Block ids are handed out from the end of the pool.
    free_gpu_blocks = list(range(num_gpu_blocks))
    block_allocations = {}
    for index, final_len in enumerate(final_prompt_lens):
        needed_blocks = round_up_to_next_block(final_len, block_size)
        block_allocations[index] = [
            free_gpu_blocks.pop() for _ in range(needed_blocks)
        ]

    metadata_list = []
    for index, (prompt_token_ids,
                cont_token_ids) in enumerate(zip(prompts, continuations)):
        data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
        # Everything except the last token is marked as already computed.
        data.update_num_computed_tokens(
            len(prompt_token_ids) + len(cont_token_ids) - 1)

        metadata = SequenceGroupMetadata(
            request_id=str(index),
            is_prompt=len(cont_token_ids) == 0,
            seq_data={index: data},
            sampling_params=SamplingParams(temperature=0.0),
            # Copy so each metadata object owns its block table.
            block_tables={index: block_allocations[index][:]},
        )
        metadata_list.append(metadata)

    return metadata_list
def create_chunked_seq_group_metadata_from_prompt(
        prompt: list[int],
        num_gpu_blocks: int,
        chunk_size: int,
        block_size: int,
        seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
    """Split one prompt into prefill chunks, one metadata entry per chunk.

    Every entry carries the full prompt with ``num_computed_tokens`` set to
    the chunk's start offset. ``do_sample`` is true only for the chunk that
    reaches the end of the prompt. All entries share the same block list and
    use ``str(seq_id)`` (``"0"`` when ``seq_id`` is ``None``) as request id.
    """
    if seq_id is None:
        seq_id = 0

    free_gpu_blocks = list(range(num_gpu_blocks))
    needed_blocks = round_up_to_next_block(len(prompt), block_size)
    block_allocations = [free_gpu_blocks.pop() for _ in range(needed_blocks)]

    prompt_len = len(prompt)
    chunk_metadata = []
    for chunk_index, start in enumerate(range(0, prompt_len, chunk_size)):
        end = start + chunk_size
        chunk_ids = prompt[start:end]

        data = SequenceData.from_seqs(prompt)
        data.update_num_computed_tokens(start)

        chunk_metadata.append(
            SequenceGroupMetadata(
                request_id=str(seq_id),
                is_prompt=True,
                do_sample=end >= len(prompt),  # terminal chunk
                seq_data={chunk_index: data},
                sampling_params=SamplingParams(temperature=0.0),
                block_tables={chunk_index: block_allocations},
                token_chunk_size=len(chunk_ids)))

    return chunk_metadata
def assert_logprobs_dict_allclose(
        actual_logprobs: list[dict[int, Logprob]],
        expected_logprobs: list[dict[int, Logprob]]) -> None:
    """Assert that two per-step logprob lists agree.

    For each pair of steps the token-id key sets must be equal, and each
    token's ``logprob`` must pass ``torch.testing.assert_close``.
    """
    for actual_step, expected_step in zip(actual_logprobs, expected_logprobs):
        assert set(actual_step.keys()) == set(expected_step.keys())

        for token_id, actual_entry in actual_step.items():
            expected_entry = expected_step[token_id]
            torch.testing.assert_close(torch.tensor(actual_entry.logprob),
                                       torch.tensor(expected_entry.logprob))
def create_sampler_output_list(
        token_ids: torch.Tensor,
        probs: GenericSequence[Optional[torch.Tensor]],
        logprobs: GenericSequence[Optional[torch.Tensor]],
        seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
    """Wrap sampled token ids into one ``SamplerOutput`` per step.

    ``token_ids`` is unpacked as ``(num_steps, batch_size)``. Each token
    becomes a single-sample ``CompletionSequenceGroupOutput`` whose parent is
    ``seq_ids[batch_index]`` (``range(batch_size)`` by default) and whose
    logprobs map the token to ``Logprob(0)``.
    """
    num_steps, batch_size = token_ids.shape
    token_ids_by_step = token_ids.tolist()

    if seq_ids is None:
        seq_ids = list(range(batch_size))

    sampler_outputs = []
    for step in range(num_steps):
        step_outputs = []
        for seq_index, token_id in enumerate(token_ids_by_step[step]):
            sample = SequenceOutput(
                output_token=token_id,
                parent_seq_id=seq_ids[seq_index],
                logprobs={token_id: Logprob(0)},
            )
            step_outputs.append(
                CompletionSequenceGroupOutput(
                    samples=[sample],
                    prompt_logprobs=None,
                ))

        sampler_outputs.append(
            SamplerOutput(outputs=step_outputs,
                          sampled_token_probs=probs[step],
                          logprobs=logprobs[step],
                          sampled_token_ids=token_ids[step]))

    return sampler_outputs
def create_batch(batch_size,
                 k,
                 prompt_len: Union[int, list[int]] = 10,
                 prev_output_token_len: int = 10,
                 seq_ids: Optional[list[int]] = None,
                 num_gpu_blocks: Optional[int] = None,
                 block_size: Optional[int] = None,
                 prefill_chunk_size: Optional[int] = None):
    """Create a batch of seq group metadata with synthetic token ids.

    Token ids come from one running counter, so they are unique across the
    batch. With a truthy ``prefill_chunk_size`` the prompts are split into
    prefill chunks, truncated to ``batch_size`` entries, and there are no
    previous output tokens. Otherwise each sequence gets
    ``prev_output_token_len`` previous output tokens and blocks for
    ``k + 1`` further tokens.

    Returns:
        ``(seq_group_metadata_list, prompts, prev_output_tokens)``.
    """
    if block_size is None:
        block_size = 8
    if num_gpu_blocks is None:
        num_gpu_blocks = 2048 // block_size

    token_counter = count()

    if isinstance(prompt_len, int):
        prompt_lens = [prompt_len] * batch_size
    else:
        prompt_lens = prompt_len

    prompts = [[next(token_counter) for _ in range(length)]
               for length in prompt_lens]

    if not prefill_chunk_size:
        prev_output_tokens = [[
            next(token_counter) for _ in range(prev_output_token_len)
        ] for _ in range(batch_size)]

        # Room for prompt, previous outputs, k lookahead tokens and one more.
        final_prompt_lens = [
            len(prompt) + len(prev_output) + k + 1
            for prompt, prev_output in zip(prompts, prev_output_tokens)
        ]

        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size, final_prompt_lens,
            prev_output_tokens, seq_ids)
        return seq_group_metadata_list, prompts, prev_output_tokens

    # Create a batch of chunked prompts.
    if not seq_ids:
        seq_ids = list(range(len(prompts)))

    chunked_metadata = []
    for prompt, sid in zip(prompts, seq_ids):
        chunked_metadata += create_chunked_seq_group_metadata_from_prompt(
            prompt, num_gpu_blocks, prefill_chunk_size, block_size, sid)

    return chunked_metadata[:batch_size], prompts, []
def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
    """Set the chunked-prefill options in ``llm_kwargs`` in place.

    A positive ``prefill_chunk_size`` enables chunked prefill and is used for
    both ``max_num_batched_tokens`` and ``max_num_seqs``. Any other value
    only sets ``enable_chunked_prefill`` to ``False``.
    """
    if not prefill_chunk_size > 0:
        llm_kwargs["enable_chunked_prefill"] = False
        return

    llm_kwargs.update(
        enable_chunked_prefill=True,
        max_num_batched_tokens=prefill_chunk_size,
        max_num_seqs=prefill_chunk_size,
    )
tests/tensorizer_loader/conftest.py
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Callable
import
pytest
from
vllm
import
LLM
,
EngineArgs
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.model_executor.model_loader
import
tensorizer
as
tensorizer_mod
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.v1.executor.abstract
import
UniProcExecutor
from
vllm.worker.worker_base
import
WorkerWrapperBase
# Model reference returned by the `model_ref` fixture below.
MODEL_REF = "facebook/opt-125m"
@pytest.fixture()
def model_ref():
    """Provide the module-level ``MODEL_REF`` to tests."""
    return MODEL_REF
@pytest.fixture(autouse=True)
def allow_insecure_serialization(monkeypatch):
    """Set ``VLLM_ALLOW_INSECURE_SERIALIZATION=1`` for every test."""
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@
pytest
.
fixture
(
autouse
=
True
)
...
...
@@ -11,7 +30,73 @@ def cleanup():
cleanup_dist_env_and_memory
(
shutdown_ray
=
True
)
@pytest.fixture()
def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path):
    """Serialize only the model tensors into ``tmp_path`` and yield that path.

    ``serialize_extra_artifacts`` is replaced with a no-op, so
    ``tensorize_vllm_model`` writes just ``model.tensors``.
    """
    monkeypatch.setattr(tensorizer_mod, "serialize_extra_artifacts",
                        lambda *args, **kwargs: None)

    engine_args = EngineArgs(model=model_ref)
    config = TensorizerConfig(tensorizer_uri=f"{tmp_path}/model.tensors")
    tensorizer_mod.tensorize_vllm_model(engine_args, config)

    yield tmp_path
@pytest.fixture(autouse=True)
def tensorizer_config():
    """Provide a ``TensorizerConfig`` whose ``tensorizer_uri`` is "vllm"."""
    return TensorizerConfig(tensorizer_uri="vllm")
@pytest.fixture()
def model_path(model_ref, tmp_path):
    """Yield ``<tmp_path>/<model_ref>/model.tensors``."""
    tensors_path = tmp_path / model_ref / "model.tensors"
    yield tensors_path
def assert_from_collective_rpc(engine: LLM, closure: Callable,
                               closure_kwargs: dict):
    """Run ``closure`` through ``engine.collective_rpc`` and return whether
    every returned result is truthy.

    Despite the name, the function does not assert; it returns the boolean.
    """
    results = engine.collective_rpc(method=closure, kwargs=closure_kwargs)
    return all(results)
# This is an object pulled from tests/v1/engine/test_engine_core.py
# Modified to strip the `load_model` method from its `_init_executor`
# method. It's purely used as a dummy utility to run methods that test
# Tensorizer functionality
class DummyExecutor(UniProcExecutor):
    """Executor that initialises its worker and device without calling
    ``load_model``."""

    def _init_executor(self) -> None:
        """Create the driver worker and initialise its device.

        Only the ``init_worker`` and ``init_device`` RPCs are issued here.
        """
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                               rpc_rank=0)
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())

        # set local rank as the device index if specified
        device_parts = str(self.vllm_config.device_config.device).split(":")
        local_rank = int(device_parts[1]) if len(device_parts) > 1 else 0

        worker_kwargs = {
            "vllm_config": self.vllm_config,
            "local_rank": local_rank,
            "rank": 0,
            "distributed_init_method": distributed_init_method,
            "is_driver_worker": True,
        }
        self.collective_rpc("init_worker", args=([worker_kwargs], ))
        self.collective_rpc("init_device")

    @property
    def max_concurrent_batches(self) -> int:
        """Fixed value of 2."""
        return 2

    def shutdown(self):
        """Shut down ``thread_pool`` without waiting, if one exists."""
        if not hasattr(self, 'thread_pool'):
            return
        self.thread_pool.shutdown(wait=False)
tests/tensorizer_loader/test_tensorizer.py
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
gc
import
json
import
os
import
pathlib
import
subprocess
import
sys
from
typing
import
Any
import
pytest
import
torch
from
vllm
import
EngineArgs
,
LLMEngine
,
RequestOutput
,
SamplingParams
import
vllm.model_executor.model_loader.tensorizer
from
vllm
import
LLM
,
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
# yapf conflicts with isort for this docstring
# yapf: disable
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
TensorSerializer
,
is_vllm_tensorized
,
open_stream
,
tensorize_vllm_model
)
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.model_loader.tensorizer_loader
import
(
BLACKLISTED_TENSORIZER_ARGS
)
# yapf: enable
from
vllm.utils
import
PlaceholderModule
from
..utils
import
VLLM_PATH
,
models_path_prefix
from
..utils
import
VLLM_PATH
,
RemoteOpenAIServer
,
models_path_prefix
from
.conftest
import
DummyExecutor
,
assert_from_collective_rpc
try
:
import
tensorizer
from
tensorizer
import
EncryptionParams
except
ImportError
:
tensorizer
=
PlaceholderModule
(
"tensorizer"
)
# type: ignore[assignment]
EncryptionParams
=
tensorizer
.
placeholder_attr
(
"EncryptionParams"
)
class TensorizerCaughtError(Exception):
    """Raised by a patched method when the expected error was caught."""
EXAMPLES_PATH
=
VLLM_PATH
/
"examples"
pytest_plugins
=
"pytest_asyncio"
,
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
...
...
@@ -42,9 +56,37 @@ prompts = [
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
seed
=
0
)
model_ref
=
os
.
path
.
join
(
models_path_prefix
,
"facebook/opt-125m"
)
tensorize_model_for_testing_script
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"tensorize_vllm_model_for_testing.py"
)
def patch_init_and_catch_error(self, obj, method_name,
                               expected_error: type[Exception]):
    """Wrap ``obj.<method_name>`` and then call ``self.load_model()``.

    The wrapper forwards to the original attribute and turns any
    ``expected_error`` it raises into ``TensorizerCaughtError``, chained to
    the original exception.

    Raises:
        ValueError: If ``obj`` has no such attribute, or it is ``None``.
    """
    unpatched = getattr(obj, method_name, None)
    if unpatched is None:
        raise ValueError("Method '{}' not found.".format(method_name))

    def wrapper(*args, **kwargs):
        try:
            outcome = unpatched(*args, **kwargs)
        except expected_error as err:
            raise TensorizerCaughtError from err
        return outcome

    setattr(obj, method_name, wrapper)

    # Loading the model is what is expected to reach the patched method.
    self.load_model()
def assert_specific_tensorizer_error_is_raised(
    executor,
    obj: Any,
    method_name: str,
    expected_error: type[Exception],
):
    """Assert that running ``patch_init_and_catch_error`` through the
    executor's ``collective_rpc`` raises ``TensorizerCaughtError``.
    """
    rpc_args = (obj, method_name, expected_error)
    with pytest.raises(TensorizerCaughtError):
        executor.collective_rpc(patch_init_and_catch_error, args=rpc_args)
def
is_curl_installed
():
...
...
@@ -62,32 +104,12 @@ def write_keyfile(keyfile_path: str):
f
.
write
(
encryption_params
.
key
)
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
def
test_can_deserialize_s3
(
vllm_runner
):
model_ref
=
os
.
path
.
join
(
models_path_prefix
,
"EleutherAI/pythia-1.4b"
)
tensorized_path
=
f
"
{
model_ref
}
/fp16/model.tensors"
with
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
tensorized_path
,
num_readers
=
1
,
s3_endpoint
=
"object.ord1.coreweave.com"
,
))
as
loaded_hf_model
:
deserialized_outputs
=
loaded_hf_model
.
generate
(
prompts
,
sampling_params
)
# noqa: E501
assert
deserialized_outputs
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
def
test_deserialized_encrypted_vllm_model_has_same_outputs
(
vllm_runner
,
tmp_path
):
model_ref
,
vllm_runner
,
tmp_path
,
model_path
):
args
=
EngineArgs
(
model
=
model_ref
)
with
vllm_runner
(
model_ref
)
as
vllm_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
key_path
=
tmp_path
/
(
model_ref
+
".key"
)
key_path
=
tmp_path
/
model_ref
/
"model.key"
write_keyfile
(
key_path
)
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
...
...
@@ -113,9 +135,9 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
def
test_deserialized_hf_model_has_same_outputs
(
hf_runner
,
vllm_runner
,
tmp_path
):
tmp_path
,
model_ref
,
model_path
):
with
hf_runner
(
model_ref
)
as
hf_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
max_tokens
=
50
outputs
=
hf_model
.
generate_greedy
(
prompts
,
max_tokens
=
max_tokens
)
with
open_stream
(
model_path
,
"wb+"
)
as
stream
:
...
...
@@ -125,7 +147,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
with
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
tensorizer_uri
=
str
(
model_path
)
,
num_readers
=
1
,
))
as
loaded_hf_model
:
deserialized_outputs
=
loaded_hf_model
.
generate_greedy
(
...
...
@@ -134,7 +156,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
assert
outputs
==
deserialized_outputs
def
test_load_without_tensorizer_load_format
(
vllm_runner
,
capfd
):
def
test_load_without_tensorizer_load_format
(
vllm_runner
,
capfd
,
model_ref
):
model
=
None
try
:
model
=
vllm_runner
(
...
...
@@ -152,7 +174,8 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd):
torch
.
cuda
.
empty_cache
()
def
test_raise_value_error_on_invalid_load_format
(
vllm_runner
,
capfd
):
def
test_raise_value_error_on_invalid_load_format
(
vllm_runner
,
capfd
,
model_ref
):
model
=
None
try
:
model
=
vllm_runner
(
...
...
@@ -211,7 +234,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
outputs
=
base_model
.
generate
(
prompts
,
sampling_params
)
# load model with two shards and serialize with encryption
model_path
=
str
(
tmp_path
/
(
model_ref
+
"-%02d.tensors"
)
)
model_path
=
str
(
tmp_path
/
model_ref
/
"
model
-%02d.tensors"
)
key_path
=
tmp_path
/
(
model_ref
+
".key"
)
tensorizer_config
=
TensorizerConfig
(
...
...
@@ -245,13 +268,13 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
@
pytest
.
mark
.
flaky
(
reruns
=
3
)
def
test_vllm_tensorized_model_has_same_outputs
(
vllm_runner
,
tmp_path
):
def
test_vllm_tensorized_model_has_same_outputs
(
model_ref
,
vllm_runner
,
tmp_path
,
model_path
):
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
model_ref
=
os
.
path
.
join
(
models_path_prefix
,
"facebook/opt-125m"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
config
=
TensorizerConfig
(
tensorizer_uri
=
str
(
model_path
))
args
=
EngineArgs
(
model
=
model_ref
,
device
=
"cuda"
)
args
=
EngineArgs
(
model
=
model_ref
)
with
vllm_runner
(
model_ref
)
as
vllm_model
:
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
...
...
@@ -266,4 +289,244 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
prompts
,
sampling_params
)
# noqa: E501
assert
outputs
==
deserialized_outputs
\ No newline at end of file
assert
outputs
==
deserialized_outputs
def test_load_with_just_model_tensors(just_serialize_model_tensors, model_ref):
    """Serve a model from its reference name plus a bare tensors file.

    For backwards compatibility, Tensorizer must still be loadable for
    inference when given the model reference name (not a local/S3 directory)
    together with the location of the model tensors.
    """
    tensors_uri = f"{just_serialize_model_tensors}/model.tensors"

    # Command-line arguments for the OpenAI API server.
    server_args = [
        "--load-format",
        "tensorizer",
        "--model-loader-extra-config",
        json.dumps({"tensorizer_uri": tensors_uri}),
    ]

    # Entering the context is the whole test: it only checks that the model
    # loads and the server initializes successfully.
    with RemoteOpenAIServer(model_ref, server_args):
        pass
def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path):
    # Purpose: check that the ``serialization_kwargs`` given to
    # TensorizerConfig show up as keyword arguments of
    # ``TensorSerializer.__init__`` when the model is saved.

    serialization_params = {
        "limit_cpu_concurrency": 2,
    }
    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path),
                              serialization_kwargs=serialization_params)
    llm = LLM(model=model_ref, )

    def serialization_test(self, *args, **kwargs):
        # This is performed in the ephemeral worker process, so monkey-patching
        # will actually work, and cleanup is guaranteed so don't
        # need to reset things

        original_dict = serialization_params
        to_compare = {}

        original = tensorizer.serialization.TensorSerializer.__init__

        # Wrapper that records the keyword arguments the serializer was
        # constructed with, then delegates to the real ``__init__``.
        def tensorizer_serializer_wrapper(self, *args, **kwargs):
            nonlocal to_compare
            to_compare = kwargs.copy()
            return original(self, *args, **kwargs)

        tensorizer.serialization.TensorSerializer.__init__ = (
            tensorizer_serializer_wrapper)

        tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"])
        self.save_tensorized_model(tensorizer_config=tensorizer_config, )
        # ``|`` binds tighter than ``==``: merging the expected params into
        # the recorded kwargs must change nothing, i.e. every expected
        # key/value pair was already present in the recorded kwargs.
        return to_compare | original_dict == to_compare

    kwargs = {"tensorizer_config": config.to_serializable()}

    # NOTE(review): assert_from_collective_rpc is defined elsewhere;
    # presumably it runs ``serialization_test`` on the workers and combines
    # their boolean results - confirm against its definition.
    assert assert_from_collective_rpc(llm, serialization_test, kwargs)
def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(
        tmp_path, capfd):
    """Check that ``deserialization_kwargs`` reach ``TensorDeserializer``.

    The model is serialized first, then a loader config carrying an illegal
    ``num_readers`` value is built and handed to
    ``assert_specific_tensorizer_error_is_raised`` together with the
    ``TypeError`` expected from ``TensorDeserializer.__init__``.
    """
    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")

    # Step 1: write a serialized model so there is something to load.
    save_config = TensorizerConfig(
        tensorizer_uri=str(model_path),
        serialization_kwargs={"limit_cpu_concurrency": 2},
    )
    tensorize_vllm_model(EngineArgs(model=model_ref), save_config)

    # Step 2: build an engine config whose loader carries the bad kwarg.
    load_config = TensorizerConfig(
        tensorizer_uri=str(model_path),
        deserialization_kwargs={"num_readers": "bar"},  # illegal value
    )
    vllm_config = EngineArgs(
        model=model_ref,
        load_format="tensorizer",
        model_loader_extra_config=load_config.to_serializable(),
    ).create_engine_config()

    assert_specific_tensorizer_error_is_raised(
        DummyExecutor(vllm_config),
        tensorizer.serialization.TensorDeserializer,
        "__init__",
        TypeError,
    )
def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd):
    """Check that ``stream_kwargs`` reach the loader's ``open_stream``.

    The model is serialized first, then a loader config carrying the stream
    mode ``"foo"`` is built and handed to
    ``assert_specific_tensorizer_error_is_raised`` together with the
    ``ValueError`` expected from ``open_stream``.
    """
    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")

    # Step 1: write a serialized model so there is something to load.
    save_config = TensorizerConfig(
        tensorizer_uri=str(model_path),
        serialization_kwargs={"limit_cpu_concurrency": 2},
    )
    tensorize_vllm_model(EngineArgs(model=model_ref), save_config)

    # Step 2: valid deserialization kwargs, but a stream mode of "foo".
    load_config = TensorizerConfig(
        tensorizer_uri=str(model_path),
        deserialization_kwargs={"num_readers": 1},
        stream_kwargs={"mode": "foo"},
    )
    vllm_config = EngineArgs(
        model=model_ref,
        load_format="tensorizer",
        model_loader_extra_config=load_config.to_serializable(),
    ).create_engine_config()

    assert_specific_tensorizer_error_is_raised(
        DummyExecutor(vllm_config),
        vllm.model_executor.model_loader.tensorizer,
        "open_stream",
        ValueError,
    )
@pytest.mark.asyncio
async def test_serialize_and_serve_entrypoints(tmp_path):
    """Serialize a model with the example script, then serve it.

    First runs ``examples/others/tensorize_vllm_model.py`` in a subprocess
    and checks its stdout, then launches ``vllm serve`` with the tensorizer
    load format and waits for the startup banner on the combined output.
    """
    model_ref = "facebook/opt-125m"
    suffix = "test"

    serialize_cmd = [
        sys.executable,
        f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py",
        "--model",
        model_ref,
        "serialize",
        "--serialized-directory",
        str(tmp_path),
        "--suffix",
        suffix,
        "--serialization-kwargs",
        '{"limit_cpu_concurrency": 4}',
    ]
    try:
        result = subprocess.run(serialize_cmd,
                                check=True,
                                capture_output=True,
                                text=True)
    except subprocess.CalledProcessError as e:
        # Surface the script's output before re-raising so the failure is
        # diagnosable from the test log.
        print("Tensorizing failed.")
        print("STDOUT:\n", e.stdout)
        print("STDERR:\n", e.stderr)
        raise

    assert "Successfully serialized" in result.stdout

    # Next, try to serve with vllm serve
    model_uri = tmp_path / "vllm" / model_ref / suffix / "model.tensors"
    loader_config_json = json.dumps(
        {
            "tensorizer_uri": str(model_uri),
            "stream_kwargs": {
                "force_http": False,
            },
            "deserialization_kwargs": {
                "verify_hash": True,
                "num_readers": 8,
            },
        },
        indent=2,
    )
    serve_args = [
        "-m",
        "vllm.entrypoints.cli.main",
        "serve",
        "--host",
        "localhost",
        "--load-format",
        "tensorizer",
        model_ref,
        "--model-loader-extra-config",
        loader_config_json,
    ]

    proc = await asyncio.create_subprocess_exec(
        sys.executable,
        *serve_args,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    assert proc.stdout is not None

    startup_banner = proc.stdout.readuntil(b"Application startup complete.")
    try:
        await asyncio.wait_for(startup_banner, 180)
    except asyncio.TimeoutError:
        pytest.fail("Server did not start successfully")
    finally:
        # Stop the server whether or not the banner was seen.
        proc.terminate()
    await proc.communicate()
@pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS)
def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd,
                                           illegal_value):
    """Loading with a blacklisted Tensorizer argument must be rejected.

    The model is serialized first, then loaded with an extra-config dict
    that contains ``illegal_value`` as a key. The load is required to raise
    ``RuntimeError`` and the captured output must contain the ``ValueError``
    message naming the rejected argument.
    """
    serialization_params = {
        "limit_cpu_concurrency": 2,
    }

    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path),
                              serialization_kwargs=serialization_params)
    args = EngineArgs(model=model_ref)
    tensorize_vllm_model(args, config)

    loader_tc = {"tensorizer_uri": str(model_path), illegal_value: "foo"}

    # pytest.raises (instead of a bare try/except) makes the test fail when
    # the load unexpectedly succeeds; previously that case passed silently
    # without checking anything.
    with pytest.raises(RuntimeError):
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=loader_tc,
        )

    out, err = capfd.readouterr()
    combined_output = out + err
    assert (f"ValueError: {illegal_value} is not an allowed "
            f"Tensorizer argument.") in combined_output
tests/test_config.py
View file @
711aa9d5
...
...
@@ -8,7 +8,7 @@ import os
from
vllm.compilation.backends
import
VllmBackend
from
vllm.config
import
(
LoadConfig
,
ModelConfig
,
PoolerConfig
,
VllmConfig
,
get_field
)
get_field
,
update_config
)
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.platforms
import
current_platform
from
utils
import
models_path_prefix
...
...
@@ -48,15 +48,43 @@ def test_get_field():
assert
c
.
default_factory
is
MISSING
# Test-only dataclass whose single field ``a`` is itself a config dataclass;
# test_update_config uses it to exercise nested updates.
@dataclass
class _TestNestedConfig:
    # default_factory gives every instance its own _TestConfigFields(a=0)
    # rather than one shared default object.
    a: _TestConfigFields = field(
        default_factory=lambda: _TestConfigFields(a=0))
def test_update_config():
    """Exercise ``update_config`` on flat and nested dataclass configs."""
    # Flat config: overriding an existing field yields the new value.
    flat = _TestConfigFields(a=0)
    assert update_config(flat, {"a": 42}).a == 42

    # Flat config: an unknown field name is rejected.
    with pytest.raises(AssertionError):
        update_config(flat, {"nonexistent": 1})

    # Nested config: the inner field can be replaced by a dataclass instance.
    replacement = _TestConfigFields(a=1, c="new_value")
    replaced = update_config(_TestNestedConfig(), {"a": replacement})
    assert replaced.a == replacement

    # Nested config: the inner field can be updated through a dict.
    nested = _TestNestedConfig()
    patched = update_config(nested, {"a": {"c": "new_value"}})
    assert patched.a.c == "new_value"

    # Nested config: a value of the wrong type for the inner field is
    # rejected.
    with pytest.raises(AssertionError):
        update_config(nested, {"a": "new_value"})
@
pytest
.
mark
.
parametrize
(
(
"model_id"
,
"expected_runner_type"
,
"expected_task"
),
[
(
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
),
"generate"
,
"generate"
),
(
os
.
path
.
join
(
models_path_prefix
,
"intfloat/multilingual-e5-small"
),
"pooling"
,
"embed"
),
(
os
.
path
.
join
(
models_path_prefix
,
"jason9693/Qwen2.5-1.5B-apeach"
),
"pooling"
,
"classify"
),
(
os
.
path
.
join
(
models_path_prefix
,
"cross-encoder/ms-marco-MiniLM-L-6-v2"
),
"pooling"
,
"classify"
),
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Math-RM-72B"
),
"pooling"
,
"reward"
),
(
os
.
path
.
join
(
models_path_prefix
,
"openai/whisper-small"
),
"
transcription
"
,
"transcription"
),
(
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
),
"generate"
,
"generate"
),
(
os
.
path
.
join
(
models_path_prefix
,
"intfloat/multilingual-e5-small"
),
"pooling"
,
"embed"
),
(
os
.
path
.
join
(
models_path_prefix
,
"jason9693/Qwen2.5-1.5B-apeach"
),
"pooling"
,
"classify"
),
(
os
.
path
.
join
(
models_path_prefix
,
"cross-encoder/ms-marco-MiniLM-L-6-v2"
),
"pooling"
,
"classify"
),
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Math-RM-72B"
),
"pooling"
,
"reward"
),
(
os
.
path
.
join
(
models_path_prefix
,
"openai/whisper-small"
),
"
generate
"
,
"transcription"
),
],
)
def
test_auto_task
(
model_id
,
expected_runner_type
,
expected_task
):
...
...
@@ -71,7 +99,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
)
assert
config
.
runner_type
==
expected_runner_type
assert
config
.
task
==
expected_task
if
config
.
runner_type
==
"pooling"
:
assert
config
.
task
==
expected_task
else
:
assert
expected_task
in
config
.
supported_tasks
@
pytest
.
mark
.
parametrize
(
...
...
@@ -100,11 +132,50 @@ def test_score_task(model_id, expected_runner_type, expected_task):
assert
config
.
task
==
expected_task
@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_task"),
    [("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto")],
)
def test_draft_task(model_id, expected_runner_type, expected_task):
    """A ModelConfig built with ``runner="draft"`` reports the expected
    runner type and task."""
    model_kwargs = {
        "runner": "draft",
        "tokenizer": model_id,
        "seed": 0,
        "dtype": "float16",
    }
    config = ModelConfig(model_id, **model_kwargs)

    assert config.runner_type == expected_runner_type
    assert config.task == expected_task
@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_task"),
    [("openai/whisper-small", "generate", "transcription")],
)
def test_transcription_task(model_id, expected_runner_type, expected_task):
    """A ModelConfig built with ``task="transcription"`` reports the expected
    runner type and task."""
    model_kwargs = {
        "task": "transcription",
        "tokenizer": model_id,
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "seed": 0,
        "dtype": "float16",
    }
    config = ModelConfig(model_id, **model_kwargs)

    assert config.runner_type == expected_runner_type
    assert config.task == expected_task
@
pytest
.
mark
.
parametrize
((
"model_id"
,
"bad_task"
),
[
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Math-RM-72B"
),
"generate"
),
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Math-RM-72B"
),
"generate"
),
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-0.6B"
),
"transcription"
),
])
def
test_incorrect_task
(
model_id
,
bad_task
):
with
pytest
.
raises
(
ValueError
,
match
=
r
"does not support t
he .* t
ask"
):
with
pytest
.
raises
(
ValueError
,
match
=
r
"does not support task
=.*
"
):
ModelConfig
(
model_id
,
task
=
bad_task
,
...
...
tests/test_sequence.py
View file @
711aa9d5
...
...
@@ -29,7 +29,6 @@ def test_sampler_output_initialization(sampler_output, sample_outputs):
assert
len
(
sampler_output
)
==
len
(
sample_outputs
)
assert
sampler_output
.
sampled_token_probs
is
None
assert
sampler_output
.
sampled_token_ids
is
None
assert
sampler_output
.
spec_decode_worker_metrics
is
None
def
test_sampler_output_getitem
(
sampler_output
,
sample_outputs
):
...
...
tests/test_utils.py
View file @
711aa9d5
...
...
@@ -15,15 +15,18 @@ import os
import
pytest
import
torch
import
zmq
from
transformers
import
AutoTokenizer
from
vllm_test_utils.monitor
import
monitor
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.transformers_utils.detokenizer_utils
import
(
convert_ids_list_to_tokens
)
from
vllm.utils
import
(
CacheInfo
,
FlexibleArgumentParser
,
LRUCache
,
MemorySnapshot
,
PlaceholderModule
,
StoreBoolean
,
bind_kv_cache
,
common_broadcastable_dtype
,
deprecate_kwargs
,
get_open_port
,
get_tcp_uri
,
is_lossless_cast
,
join_host_port
,
make_zmq_path
,
make_zmq_socket
,
memory_profiling
,
current_stream
,
deprecate_kwargs
,
get_open_port
,
get_tcp_uri
,
is_lossless_cast
,
join_host_port
,
make_zmq_path
,
make_zmq_socket
,
memory_profiling
,
merge_async_iterators
,
sha256
,
split_host_port
,
split_zmq_path
,
supports_kw
,
swap_dict_values
)
...
...
@@ -458,6 +461,31 @@ def test_bind_kv_cache():
assert
ctx
[
'layers.2.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
2
]
assert
ctx
[
'layers.3.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
3
]
def test_bind_kv_cache_kv_sharing():
    """Layers listed in the sharing map get bound to their target's cache.

    Layers 2 and 3 are mapped onto layers 1 and 0 respectively, so after
    binding they must reference the very same tensor objects.
    """
    from vllm.attention import Attention

    layer_names = [f"layers.{idx}.self_attn" for idx in range(4)]
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = [torch.zeros((1, )) for _ in layer_names]
    shared_kv_cache_layers = {
        layer_names[2]: layer_names[1],
        layer_names[3]: layer_names[0],
    }

    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)

    # layer index -> index of the cache tensor that layer must be bound to
    expected_binding = ((0, 0), (1, 1), (2, 1), (3, 0))
    for layer_idx, cache_idx in expected_binding:
        bound = ctx[layer_names[layer_idx]].kv_cache[0]
        assert bound is kv_cache[cache_idx]
def
test_bind_kv_cache_non_attention
():
from
vllm.attention
import
Attention
...
...
@@ -921,3 +949,52 @@ def test_split_host_port():
def
test_join_host_port
():
assert
join_host_port
(
"127.0.0.1"
,
5555
)
==
"127.0.0.1:5555"
assert
join_host_port
(
"::1"
,
5555
)
==
"[::1]:5555"
def test_convert_ids_list_to_tokens():
    """Compare the tokenizer's raw token strings with the output of
    ``convert_ids_list_to_tokens`` for the same ids."""
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    # token_ids = [9707, 11, 1879, 0]
    token_ids = tokenizer.encode("Hello, world!")

    # The tokenizer itself reports the leading space as the 'Ġ' marker.
    raw_tokens = tokenizer.convert_ids_to_tokens(token_ids)
    assert raw_tokens == ['Hello', ',', 'Ġworld', '!']

    # The helper is expected to return a plain space in its place.
    readable_tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
    assert readable_tokens == ['Hello', ',', ' world', '!']
def test_current_stream_multithread():
    """``current_stream()`` in the main thread ignores a child's stream.

    A child thread enters ``torch.cuda.stream(child_stream)`` and stays
    there while the main thread checks that ``current_stream()`` still
    returns the main thread's default stream.
    """
    import threading
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    main_default_stream = torch.cuda.current_stream()
    child_stream = torch.cuda.Stream()

    thread_stream_ready = threading.Event()
    thread_can_exit = threading.Event()

    def child_thread_func():
        # Hold the stream context open until the main thread is done.
        with torch.cuda.stream(child_stream):
            thread_stream_ready.set()
            thread_can_exit.wait(timeout=10)

    child_thread = threading.Thread(target=child_thread_func)
    child_thread.start()

    try:
        assert thread_stream_ready.wait(
            timeout=5), "Child thread failed to enter stream context in time"

        main_current_stream = current_stream()

        assert main_current_stream != child_stream, (
            "Main thread's current_stream was contaminated by child thread")
        assert main_current_stream == main_default_stream, (
            "Main thread's current_stream is not the default stream")
    finally:
        # Release the child unconditionally. If this only happened on the
        # success path, a failed assertion above would leave the child
        # waiting for 10 s, the 5 s join below would time out, and the
        # pytest.fail here would mask the real assertion error.
        thread_can_exit.set()
        child_thread.join(timeout=5)
        if child_thread.is_alive():
            pytest.fail("Child thread failed to exit properly")
tests/tokenization/test_detokenize.py
View file @
711aa9d5
...
...
@@ -395,7 +395,7 @@ def test_decode_prompt_logprobs_chunked_prefill(
logprobs
=
5
,
prompt_logprobs
=
5
,
temperature
=
0.0
)
vllm_results
=
vllm_model
.
model
.
generate
(
vllm_results
=
vllm_model
.
llm
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
for
idx
,
result
in
enumerate
(
vllm_results
):
...
...
tests/tokenization/test_do_lower_case.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
TOKENIZER_NAMES = ["BAAI/bge-base-en"]


@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
@pytest.mark.parametrize("n_tokens", [510])
def test_special_tokens(tokenizer_name: str, n_tokens: int):
    """Encoding ``n_tokens`` repetitions of ``[UNK]`` yields ``n_tokens + 2``
    token ids."""
    tokenizer = get_tokenizer(tokenizer_name, revision="main")

    repeated_unk = '[UNK]' * n_tokens
    encoded_length = len(tokenizer.encode(repeated_unk))

    # NOTE(review): the two extra ids are presumably the special tokens the
    # tokenizer adds around the text - confirm against the tokenizer config.
    assert encoded_length == n_tokens + 2
tests/tool_use/test_glm4_moe_tool_parser.py
View file @
711aa9d5
...
...
@@ -254,7 +254,9 @@ def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser):
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>
meaningwhile, I will also check the weather in Shanghai.
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Shanghai</arg_value>
...
...
@@ -402,4 +404,4 @@ def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
# Incomplete tool calls should not be extracted
assert
not
extracted_tool_calls
.
tools_called
assert
extracted_tool_calls
.
tool_calls
==
[]
assert
extracted_tool_calls
.
content
==
model_output
\ No newline at end of file
assert
extracted_tool_calls
.
content
==
model_output
tests/tool_use/test_kimi_k2_tool_parser.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import
json
import
pytest
from
vllm.entrypoints.openai.protocol
import
FunctionCall
,
ToolCall
from
vllm.entrypoints.openai.tool_parsers
import
KimiK2ToolParser
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
# Use a common model that is likely to be available
MODEL = "moonshotai/Kimi-K2-Instruct"


@pytest.fixture(scope="module")
def kimi_k2_tokenizer():
    """Module-scoped tokenizer for ``MODEL``, loaded with remote code
    enabled."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
    return tokenizer
@pytest.fixture
def kimi_k2_tool_parser(kimi_k2_tokenizer):
    """Fresh ``KimiK2ToolParser`` per test, built on the shared tokenizer."""
    parser = KimiK2ToolParser(kimi_k2_tokenizer)
    return parser
def assert_tool_calls(actual_tool_calls: list[ToolCall],
                      expected_tool_calls: list[ToolCall]):
    """Assert that the extracted tool calls match the expected ones.

    Both lists must have the same length. Each pair is compared on type and
    function, and the actual call's id must look like
    ``functions.<name>:<digits>`` with ``<name>`` equal to the expected
    function name.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        assert actual.type == "function"
        assert actual.function == expected.function

        # Tool call id format checks.
        call_id = actual.id
        assert call_id.startswith("functions.")
        assert call_id.split(':')[-1].isdigit()
        name_in_id = call_id.split('.')[1].split(':')[0]
        assert name_in_id == expected.function.name
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
    """Plain text without tool-call markers is returned as content only."""
    plain_text = "This is a test"

    extraction = kimi_k2_tool_parser.extract_tool_calls(
        plain_text, request=None)  # type: ignore[arg-type]

    assert not extraction.tools_called
    assert extraction.tool_calls == []
    assert extraction.content == plain_text
def _get_weather_call(index: int, city: str) -> ToolCall:
    """Build the expected ``get_weather`` tool call number *index*."""
    return ToolCall(
        id=f"functions.get_weather:{index}",
        function=FunctionCall(
            name="get_weather",
            arguments=json.dumps({"city": city}),
        ),
        type="function",
    )


@pytest.mark.parametrize(
    ids=[
        "tool_call_with_content_before",
        "multi_tool_call_with_content_before",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
            [_get_weather_call(0, "Beijing")],
            "I'll help you check the weather. ",
        ),
        (
            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""",
            [
                _get_weather_call(0, "Beijing"),
                _get_weather_call(1, "Shanghai"),
            ],
            "I'll help you check the weather. ",
        ),
    ],
)
def test_extract_tool_calls(kimi_k2_tool_parser, model_output,
                            expected_tool_calls, expected_content):
    """Tool calls preceded by free text are extracted and the text is kept
    as content."""
    extraction = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extraction.tools_called
    assert_tool_calls(extraction.tool_calls, expected_tool_calls)
    assert extraction.content == expected_content
def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser):
    """we'll return every funcall result"""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""

    extraction = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extraction.tools_called
    # Both calls are expected back, including the one whose JSON arguments
    # are missing the closing brace.
    extracted_names = [call.function.name for call in extraction.tool_calls]
    assert len(extracted_names) == 2
    assert extracted_names[0] == "invalid_get_weather"
    assert extracted_names[1] == "valid_get_weather"
def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser):
    """we'll return every funcall result"""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""

    extraction = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extraction.tools_called
    # The first call uses ".0" instead of ":0" in its id, so only the
    # well-formed second call is expected back.
    calls = extraction.tool_calls
    assert len(calls) == 1
    assert calls[0].function.name == "valid_get_weather"
def test_streaming_basic_functionality(kimi_k2_tool_parser):
    """Test basic streaming functionality."""
    # Start from a clean streaming state.
    parser = kimi_k2_tool_parser
    parser.current_tool_name_sent = False
    parser.prev_tool_call_arr = []
    parser.current_tool_id = -1
    parser.streamed_args_for_tool = []

    # A single, complete tool call.
    current_text = """ check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>"""

    # First call should handle the initial setup
    result = parser.extract_tool_calls_streaming(
        previous_text="I'll help you",
        current_text=current_text,
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The result might be None or contain tool call information, depending
    # on the parser's internal state management.
    has_tool_calls = (result is not None and hasattr(result, 'tool_calls')
                      and result.tool_calls)
    if has_tool_calls:
        # NOTE(review): this check can never fail (a length is always >= 0);
        # the test effectively only verifies that the call does not raise.
        assert len(result.tool_calls) >= 0
def test_streaming_no_tool_calls(kimi_k2_tool_parser):
    """Test streaming when there are no tool calls."""
    delta = " without any tool calls."
    streaming_kwargs = {
        "previous_text": "This is just regular text",
        "current_text": "This is just regular text without any tool calls.",
        "delta_text": delta,
        "previous_token_ids": [],
        "current_token_ids": [],
        "delta_token_ids": [],
        "request": None,
    }

    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        **streaming_kwargs)

    # The delta text must come back unchanged as content.
    assert result is not None
    assert hasattr(result, 'content')
    assert result.content == delta
tests/tool_use/test_qwen3coder_tool_parser.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
collections.abc
import
Generator
from
typing
import
Optional
import
pytest
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionToolsParam
,
DeltaMessage
,
FunctionCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser
import
(
Qwen3CoderToolParser
)
from
vllm.transformers_utils.detokenizer
import
detokenize_incrementally
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
get_tokenizer
MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8"


@pytest.fixture(scope="module")
def qwen3_tokenizer():
    """Module-scoped tokenizer for ``MODEL``."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def qwen3_tool_parser(qwen3_tokenizer):
    """Fresh ``Qwen3CoderToolParser`` per test, built on the shared
    tokenizer."""
    parser = Qwen3CoderToolParser(qwen3_tokenizer)
    return parser
@pytest.fixture
def sample_tools():
    """Two tool definitions used by the tests: ``get_current_weather`` and
    ``calculate_area``."""
    weather_function = {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city name"
                },
                "state": {
                    "type": "string",
                    "description": "The state code"
                },
                "unit": {
                    "type": "string",
                    "enum": ["fahrenheit", "celsius"]
                },
            },
            "required": ["city", "state"],
        },
    }
    area_function = {
        "name": "calculate_area",
        "description": "Calculate area of a shape",
        "parameters": {
            "type": "object",
            "properties": {
                "shape": {
                    "type": "string"
                },
                "dimensions": {
                    "type": "object"
                },
                "precision": {
                    "type": "integer"
                },
            },
        },
    }
    return [
        ChatCompletionToolsParam(type="function", function=function_spec)
        for function_spec in (weather_function, area_function)
    ]
def assert_tool_calls(actual_tool_calls: list[ToolCall],
                      expected_tool_calls: list[ToolCall]):
    """Assert that the extracted tool calls match the expected ones.

    Both lists must have the same length. Each pair is compared on type,
    function name, and the JSON-decoded arguments; ids are not compared.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        # Qwen3 parser doesn't generate IDs during extraction
        assert actual.type == "function"
        assert actual.function.name == expected.function.name

        # Compare decoded JSON so key order and spacing do not matter.
        actual_args = json.loads(actual.function.arguments)
        expected_args = json.loads(expected.function.arguments)
        assert actual_args == expected_args
def stream_delta_message_generator(
    qwen3_tool_parser: Qwen3CoderToolParser,
    qwen3_tokenizer: AnyTokenizer,
    model_output: str,
    request: Optional[ChatCompletionRequest] = None
) -> Generator[DeltaMessage, None, None]:
    """Feed ``model_output`` to the streaming parser one token at a time.

    The text is encoded once (without special tokens). Each token id is then
    detokenized incrementally and the resulting text delta is passed to
    ``extract_tool_calls_streaming`` together with the previous and current
    text and token ids. Every truthy result of that call is yielded.
    """
    all_token_ids = qwen3_tokenizer.encode(model_output,
                                           add_special_tokens=False)

    # State carried from one token to the next.
    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[:i + 1]

        (new_tokens, delta_text, new_prefix_offset,
         new_read_offset) = detokenize_incrementally(
             tokenizer=qwen3_tokenizer,
             all_input_ids=current_token_ids,
             prev_tokens=previous_tokens,
             prefix_offset=prefix_offset,
             read_offset=read_offset,
             skip_special_tokens=False,
             spaces_between_special_tokens=True,
         )

        current_text = previous_text + delta_text

        delta_message = qwen3_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=request,
        )
        # Falsy results (e.g. None) are skipped rather than yielded.
        if delta_message:
            yield delta_message

        # Advance the state for the next token. The conditional expression
        # applies to the whole sum: append when tokens already exist,
        # otherwise start from new_tokens.
        previous_text = current_text
        previous_tokens = (previous_tokens +
                           new_tokens if previous_tokens else new_tokens)
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset
def test_extract_tool_calls_no_tools(qwen3_tool_parser):
    """Plain text without tool-call markup is returned as content only."""
    plain_text = "This is a test response without any tool calls"

    extraction = qwen3_tool_parser.extract_tool_calls(
        plain_text, request=None)  # type: ignore[arg-type]

    assert not extraction.tools_called
    assert extraction.tool_calls == []
    assert extraction.content == plain_text
@pytest.mark.parametrize(
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            '''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''',
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            None,
        ),
        (
            '''Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''',
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "Sure! Let me check the weather for you.",
        ),
        (
            '''<tool_call>
<function=calculate_area>
<parameter=shape>
rectangle
</parameter>
<parameter=dimensions>
{"width": 10,
"height": 20}
</parameter>
<parameter=precision>
2
</parameter>
</function>
</tool_call>''',
            [
                ToolCall(function=FunctionCall(
                    name="calculate_area",
                    arguments=json.dumps({
                        "shape": "rectangle",
                        "dimensions": {
                            "width": 10,
                            "height": 20,
                        },
                        "precision": 2,
                    }),
                )),
            ],
            None,
        ),
        (
            '''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
<parameter=state>
FL
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''',
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Orlando",
                        "state": "FL",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            None,
        ),
        (
            '''Let me calculate that area for you.<tool_call>
<function=calculate_area>
<parameter=shape>
circle
</parameter>
<parameter=dimensions>
{"radius": 15.5}
</parameter>
<parameter=precision>
3
</parameter>
</function>
</tool_call>''',
            [
                ToolCall(function=FunctionCall(
                    name="calculate_area",
                    arguments=json.dumps({
                        "shape": "circle",
                        "dimensions": {
                            "radius": 15.5,
                        },
                        "precision": 3,
                    }),
                )),
            ],
            "Let me calculate that area for you.",
        ),
    ],
    ids=[
        "single_tool",
        "single_tool_with_content",
        "single_tool_multiline_param",
        "parallel_tools",
        "tool_with_typed_params",
    ],
)
def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output,
                            expected_tool_calls, expected_content):
    """Non-streaming extraction yields the expected calls and leading content."""
    chat_request = ChatCompletionRequest(model=MODEL,
                                         messages=[],
                                         tools=sample_tools)

    result = qwen3_tool_parser.extract_tool_calls(model_output,
                                                  request=chat_request)

    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser, sample_tools):
    """A bare <function=...> block without <tool_call> tags is still parsed."""
    untagged_output = '''<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>'''

    chat_request = ChatCompletionRequest(model=MODEL,
                                         messages=[],
                                         tools=sample_tools)
    result = qwen3_tool_parser.extract_tool_calls(untagged_output,
                                                  request=chat_request)

    assert result.tools_called
    assert len(result.tool_calls) == 1
    only_call = result.tool_calls[0]
    assert only_call.function.name == "get_current_weather"
def test_extract_tool_calls_type_conversion(qwen3_tool_parser):
    """Test parameter type conversion based on tool schema.

    A single tool is declared whose parameters cover integer, float,
    boolean, string and object types. The model output supplies every value
    as raw text; the decoded JSON arguments are then checked to hold natively
    typed values.
    """
    tools = [
        ChatCompletionToolsParam(type="function",
                                 function={
                                     "name": "test_types",
                                     "parameters": {
                                         "type": "object",
                                         "properties": {
                                             "int_param": {
                                                 "type": "integer"
                                             },
                                             "float_param": {
                                                 "type": "float"
                                             },
                                             "bool_param": {
                                                 "type": "boolean"
                                             },
                                             "str_param": {
                                                 "type": "string"
                                             },
                                             "obj_param": {
                                                 "type": "object"
                                             }
                                         }
                                     }
                                 })
    ]

    model_output = '''<tool_call>
<function=test_types>
<parameter=int_param>
42
</parameter>
<parameter=float_param>
3.14
</parameter>
<parameter=bool_param>
true
</parameter>
<parameter=str_param>
hello world
</parameter>
<parameter=obj_param>
{"key": "value"}
</parameter>
</function>
</tool_call>'''

    request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
    extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
        model_output, request=request)

    # Check the preconditions explicitly so a parser regression fails with a
    # clear assertion instead of an IndexError on the indexing below.
    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 1

    args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
    assert args["int_param"] == 42
    assert args["float_param"] == 3.14
    assert args["bool_param"] is True
    assert args["str_param"] == "hello world"
    assert args["obj_param"] == {"key": "value"}
@pytest.mark.parametrize(
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            "This is a test without tools",
            [],
            "This is a test without tools",
        ),
        (
            '''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''',
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "",
        ),
        (
            '''Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''',
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "Sure! Let me check the weather for you.",
        ),
        (
            '''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
<parameter=state>
FL
</parameter>
<parameter=unit>
celsius
</parameter>
</function>
</tool_call>''',
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Orlando",
                        "state": "FL",
                        "unit": "celsius",
                    }),
                )),
            ],
            "",
        ),
    ],
    ids=[
        "no_tools",
        "single_tool",
        "single_tool_with_content",
        "parallel_tools",
    ],
)
def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
                                      sample_tools, model_output,
                                      expected_tool_calls, expected_content):
    """Reassemble streamed deltas and compare them with the expected calls."""
    chat_request = ChatCompletionRequest(model=MODEL,
                                         messages=[],
                                         tools=sample_tools)

    content_parts = []
    # Accumulated id / name / arguments / type, keyed by tool-call index.
    tool_states = {}

    deltas = stream_delta_message_generator(qwen3_tool_parser,
                                            qwen3_tokenizer, model_output,
                                            chat_request)
    for delta in deltas:
        # The tool parser must never emit a role.
        assert not delta.role

        if delta.content:
            content_parts.append(delta.content)

        for call in delta.tool_calls or ():
            # Create the per-index record the first time an index shows up.
            record = tool_states.setdefault(call.index, {
                "id": None,
                "name": None,
                "arguments": "",
                "type": None
            })

            if call.id:
                record["id"] = call.id

            if call.type:
                assert call.type == "function"
                record["type"] = call.type

            fn = call.function
            if not fn:
                continue

            if fn.name:
                # A tool name may be announced only once per index.
                assert record["name"] is None
                record["name"] = fn.name

            if fn.arguments is not None:
                # Argument text arrives in fragments; concatenate in order.
                record["arguments"] += fn.arguments

    # Non-tool content must match exactly.
    assert "".join(content_parts) == expected_content

    # Exactly as many tool calls were seen as expected.
    assert len(tool_states) == len(expected_tool_calls)

    # Each reassembled call matches its expectation, position by position.
    for position, expected in enumerate(expected_tool_calls):
        record = tool_states[position]
        assert record["id"] is not None
        assert record["type"] == "function"
        assert record["name"] == expected.function.name

        streamed_args = record["arguments"]
        assert streamed_args is not None
        assert json.loads(streamed_args) == json.loads(
            expected.function.arguments)
def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser,
                                                  qwen3_tokenizer,
                                                  sample_tools):
    """Streaming emits content, then a tool header, then argument fragments."""
    model_output = '''I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>'''

    chat_request = ChatCompletionRequest(model=MODEL,
                                         messages=[],
                                         tools=sample_tools)

    chunks = list(
        stream_delta_message_generator(qwen3_tool_parser, qwen3_tokenizer,
                                       model_output, chat_request))

    # The output must be split across several deltas.
    assert len(chunks) > 3

    # The stream opens with plain content and no tool calls.
    first = chunks[0]
    assert first.content is not None
    assert first.tool_calls is None or first.tool_calls == []

    # The first delta whose leading tool call has an id is the header.
    header = next(
        (c for c in chunks if c.tool_calls and c.tool_calls[0].id), None)
    if header is not None:
        head_call = header.tool_calls[0]
        assert head_call.function.name == "get_current_weather"
        assert head_call.type == "function"
        # The header carries no argument text yet.
        assert head_call.function.arguments == ""
    assert header is not None

    # Gather every non-empty argument fragment in arrival order.
    arg_chunks = [
        c.tool_calls[0].function.arguments for c in chunks
        if c.tool_calls and c.tool_calls[0].function.arguments
    ]

    # Arguments must arrive in more than one piece.
    assert len(arg_chunks) > 1

    # Joined together, the fragments form valid JSON with the expected values.
    parsed_args = json.loads("".join(arg_chunks))
    assert parsed_args["city"] == "Dallas"
    assert parsed_args["state"] == "TX"
Prev
1
…
16
17
18
19
20
21
22
23
24
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment