Commit 711aa9d5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.0' into v0.10.0-dev

parents 751c492c 6d8d0a24
......@@ -21,7 +21,7 @@ def v1(run_with_both_engines):
def _generate(
model: LLM,
llm: LLM,
prompt: str,
num_prompt_tokens: int,
temperature: float = 0,
......@@ -33,7 +33,7 @@ def _generate(
)
# [([output_token_ids, ], [output_text, ]), ]
output = model.generate([prompt], sampling_params=sampling_params)
output = llm.generate([prompt], sampling_params=sampling_params)
output_token_ids = output[0][0][0][num_prompt_tokens:]
# [0] first (and only) request output
......@@ -68,10 +68,10 @@ class TestOneTokenBadWord:
assert self.target_token_id not in output_token_ids
def _generate(self,
model: LLM,
llm: LLM,
bad_words: Optional[list[str]] = None) -> list[int]:
return _generate(
model=model,
llm=llm,
prompt=self.PROMPT,
num_prompt_tokens=self.num_prompt_tokens,
bad_words=bad_words,
......@@ -158,10 +158,10 @@ class TestTwoTokenBadWord:
or (self.neighbour_token_id2 in output_token_ids))
def _generate(self,
model: LLM,
llm: LLM,
bad_words: Optional[list[str]] = None) -> list[int]:
return _generate(
model=model,
llm=llm,
prompt=self.PROMPT,
num_prompt_tokens=self.num_prompt_tokens,
bad_words=bad_words,
......
......@@ -51,7 +51,7 @@ def test_random_sample_with_seed(
sampling_params_seed_2 = copy.deepcopy(sampling_params)
sampling_params_seed_2.seed = 200
llm = vllm_model.model
llm = vllm_model.llm
for prompt in example_prompts:
for params in (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for rejection sampling."""
import pytest
import torch
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.model_executor.utils import set_random_seed
CUDA_DEVICES = [f"cuda:{i}" for i in range(1)]
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
This file tests V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def get_zero_temperature_prob_dist(batch_size, k, vocab_size):
"""
Generates a fake temperature zero probability distribution.
Returns:
1. A fake temperature zero probability distribution of shape
[batch_size, k, vocab_size]
2. Tensor of shape [batch_size, k] containing the token ids
of the probability 1.0 tokens at each position.
"""
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
probs = torch.rand(batch_size, k, vocab_size)
_, zero_temperature_token_ids = torch.max(probs, dim=-1)
# set the probability of the tokens with ids in zero_temperature_token_ids
# to 1 and the rest to 0.
target_probs = torch.zeros_like(probs).scatter_(
-1, zero_temperature_token_ids.unsqueeze(-1), 1.0)
return target_probs, zero_temperature_token_ids
def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
token_ids_to_exclude: torch.Tensor):
"""
Returns a tensor of shape [batch_size, k] of fake draft token ids
drawn randomly from a vocab of size vocab_size. We however ensure
that token_ids from token_ids_to_exclude are excluded at the
corresponding positions.
"""
draft_token_ids = torch.empty(batch_size, k, dtype=torch.long)
for i in range(batch_size):
for j in range(k):
# Generate a random token ID excluding token_ids_to_exclude[i, j]
while True:
token_id = torch.randint(0, vocab_size, (1, )).item()
if token_id != token_ids_to_exclude[i, j]:
draft_token_ids[i, j] = token_id
break
return draft_token_ids
def get_acceptance_sampler(
posterior_threshold: float = 0.03,
posterior_alpha: float = 0.9,
strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
strict_mode)
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str):
"""
Tests that the TypicalAcceptancSampler forward succeeds for
different combinations of k, vocab_size, batch_size and num devices.
"""
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str, device: str):
"""
Tests that we throw an exception of the token ids fall outside
the bound of the provided vocabulary.
"""
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
# Verify that appropriate exceptions are thrown for out
# of bound vocabs.
oob_token_ids = None
if which_token_ids == "bonus_token_ids":
oob_token_ids = bonus_token_ids
elif which_token_ids == "draft_token_ids":
oob_token_ids = draft_token_ids
else:
raise AssertionError()
if above_or_below_vocab_range == "above":
rogue_token_id = vocab_size + 1
elif above_or_below_vocab_range == "below":
rogue_token_id = -1
else:
raise AssertionError()
oob_token_ids[0][0] = rogue_token_id
with pytest.raises(AssertionError):
typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a uniform target probability
distribution.
This test verifies that when provided with a uniform target probability
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
entropy of the uniform target distribution being high should lead to all
draft tokens being accepted.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
# We are using a uniform target probability distribution.
# For a uniform distribution the entropy is very high and it
# should lead to all draft tokens being accepted. Verify that.
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
assert torch.all(output_token_ids[:, :k] == draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a zero-temperature target
probability distribution.
This test verifies that when using a zero-temperature target probability
distribution, where only one token has a probability of 1.0, the
TypicalAcceptanceSampler correctly rejects all draft tokens that do not
match this probability. Additionally, it ensures that when all draft
tokens are rejected, the sampler falls back to greedy sampling to select a
single token from the target distribution.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
# Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
# The target probaility distribution is a temperature zero distribution
# with zero entropy. Since our draft token ids don't match the probability
# 1.0 tokens in the target distribution we will reject all of them and
# fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same.
output_token_ids = typical_acceptance_sampler(
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, -1] == -1)
assert torch.all(output_token_ids[:, 0] == zero_temperature_token_ids[:,
0])
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a mixed target probability
distribution.
This test ensures that the TypicalAcceptanceSampler handles a mixed
target probability distribution correctly. Specifically, it uses a
zero-temperature distribution for some sequences and a uniform
distribution for others. The test verifies that:
- For sequences with a zero-temperature distribution, only the token
with a probability of 1.0 is accepted, and all other tokens are rejected.
- For sequences with a uniform distribution, all draft tokens are
accepted.
"""
set_random_seed(seed)
k = 3
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs = target_with_bonus_probs[:, :-1]
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
target_probs[[1, 3]] = uniform_probs
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
# verify the shape of output_token_ids
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
# For sequences 0 and 2 verify that only 1 token is accepted
# which is the token with probability 1.0 in the target distribution
# at position 0.
assert torch.all(output_token_ids[[0, 2], 1:] == -1)
assert (torch.all(output_token_ids[[0, 2],
0] == zero_temperature_token_ids[[0, 2],
0]))
# For sequences 1 and 3 verify that all tokens are accepted since the
# target probability distribution is uniform. In addition verify that
# we also accept the bonus tokens.
assert torch.all(
output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
assert torch.all(output_token_ids[[1, 3], -1] != -1)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.
This test verifies that the TypicalAcceptanceSampler correctly accepts or
rejects draft tokens based on a zero-temperature target probability
distribution. Specifically, it ensures that:
- When all draft tokens match tokens with a probability of 1.0 in the
target distribution, all draft tokens are accepted.
- When only some draft tokens match tokens with a probability of 1.0 in
the target distribution, only those matching tokens are accepted, and the
rest are rejected.
"""
set_random_seed(seed)
k = 5
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
draft_token_ids = zero_temperature_token_ids
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the
# draft tokens and the recovered token and rest as -1
draft_token_ids_to_replace = get_draft_token_ids(
batch_size, k, vocab_size, zero_temperature_token_ids)
draft_token_ids = torch.cat(
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler(
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
assert torch.all(
output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
assert torch.all(output_token_ids[:, -3:] == -1)
@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
thresholds and alpha values we can change the acceptance behavior of the
sampler.
"""
set_random_seed(seed)
k = 5
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
# id has probability 1.0 and others have a very low probability of
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0. Without any changes to the posterior thresholds
# none of the draft tokens are accepted.
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs[target_probs == 0] = 0.00001
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 1:-1] == -1)
# Change the posterior threshold values to 0.0 so that we will
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_get_recovered_token_ids(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
This test verifies that the `_get_recovered_token_ids` method of the
TypicalAcceptanceSampler correctly identifies the token IDs to be used
as recovered token IDs based on the target probability distribution.
Specifically, it ensures that the method correctly identifies the
tokens with the highest probability for each sequence in the batch.
"""
set_random_seed(seed)
k = 10
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
actual_replacement_tokens = (
typical_acceptance_sampler._get_recovered_token_ids(target_probs))
assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from itertools import cycle
from typing import Optional, Union
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import PromptLogprobs, SampleLogprobs
from ...models.utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
check_logprobs_close, check_outputs_equal)
from ...utils import RemoteOpenAIServer
PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed):
def generate():
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
llm = LLM(**kwargs)
if seed is not None:
set_random_seed(seed)
yield llm
del llm
cleanup_dist_env_and_memory()
return generate
def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified.
if (llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.method == "ngram"):
from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker,
NGramWorker)
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> tuple[list[str], list[list[int]], float]:
tokens: list[str] = []
token_ids: list[list[int]] = []
acceptance_rate: float = -1.0
for llm in llm_generator():
maybe_assert_ngram_worker(llm)
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
# Fetch acceptance rate if logging is enabled.
if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
stat_logger = stat_loggers["prometheus"]
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
del llm
return tokens, token_ids, acceptance_rate
def check_logprobs_correctness(
spec_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
baseline_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
disable_logprobs: bool = False,
):
"""Compare sampled and prompt logprobs between baseline and spec decoding
"""
if not disable_logprobs:
return check_logprobs_close(
outputs_0_lst=baseline_outputs,
outputs_1_lst=spec_outputs,
name_0="org",
name_1="sd",
)
# Check correctness when disable_logprobs == True
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
# Check generated token logprobs.
spec_logprobs = spec_output[2]
baseline_logprobs = baseline_output[2]
_check_logprobs_when_output_disabled(spec_logprobs,
baseline_logprobs,
is_prompt_logprobs=False)
# Check prompt logprobs too, if they exist
if len(baseline_output) == 4:
assert len(spec_output) == 4
spec_prompt_logprobs = spec_output[3]
baseline_prompt_logprobs = baseline_output[3]
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
baseline_prompt_logprobs,
is_prompt_logprobs=True)
def _check_logprobs_when_output_disabled(
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
is_prompt_logprobs: bool = False,
):
# Prompt logprobs are optional
if is_prompt_logprobs and baseline_logprobs is None:
assert spec_logprobs is None
return
assert spec_logprobs is not None
assert baseline_logprobs is not None
assert len(spec_logprobs) == len(baseline_logprobs)
# For each generated position of the sequence.
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
zip(spec_logprobs, baseline_logprobs)):
# First prompt logprob is expected to be None
if is_prompt_logprobs and baseline_pos_logprobs is None:
assert spec_pos_logprobs is None
assert pos == 0
continue
assert spec_pos_logprobs is not None
assert baseline_pos_logprobs is not None
# When disabled, the 1 logprob is returned with dummy values for the
# score and rank, but the token id should match the baseline model
assert len(spec_pos_logprobs) == 1
(spec_pos_logprob_token_id,
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
assert spec_pos_logprob_token_id in baseline_pos_logprobs
def run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: Optional[int] = 0,
temperature: float = 0.0,
disable_seed: bool = False,
ignore_eos: bool = True,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
disable_logprobs: bool = False):
org_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**baseline_llm_kwargs,
}
sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
if disable_seed:
seed = None
sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
ignore_eos=ignore_eos,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs)
with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
with vllm_runner(**sd_args) as vllm_model:
if ensure_all_accepted or expected_acceptance_rate is not None:
# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers[
'prometheus']
stat_logger.local_interval = -100
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
if ensure_all_accepted or expected_acceptance_rate is not None:
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
if ensure_all_accepted:
assert True
# FIXME: ci fails to log acceptance rate.
# It works locally.
# assert acceptance_rate == 1.0
if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2
# Only pass token entries, not the logprobs
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
outputs_1_lst=[out[0:2] for out in sd_outputs],
name_0="org",
name_1="sd")
# Check logprobs if requested
if logprobs is not None or prompt_logprobs is not None:
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
def run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: int = 0,
temperature: float = 0.0,
logprobs: Optional[int] = None):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs
arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
env1 = env2 = None
max_wait_seconds = 240
results = []
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
for args, env in ((arg1, env1), (arg2, env2)):
with RemoteOpenAIServer(model,
args,
env_dict=env,
max_wait_seconds=max_wait_seconds) as server:
client = server.get_client()
completion = client.completions.create(model=model,
prompt=prompts,
max_tokens=max_output_len,
seed=seed,
temperature=temperature,
logprobs=logprobs)
results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"logprobs": [choice.logprobs for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
completion.usage,
})
n = len(results) // 2
arg1_results = results[:n]
arg2_results = results[n:]
# Separate logprobs to avoid asserting exact equality.
arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
arg2_logprobs = [r.pop("logprobs") for r in arg2_results]
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, (
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
f"{arg1_result=} != {arg2_result=}")
if logprobs:
for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
for l1, l2 in zip(logs1, logs2):
assert l1.tokens == l2.tokens
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from .utils import create_seq_group_metadata_from_prompts, mock_worker
@pytest.mark.parametrize('num_target_seq_ids', [100])
@pytest.mark.skip_global_cleanup
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
"""Verify all new sequence ids are greater than all input
seq ids.
"""
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
all_seq_ids = [
[1, 3, 5, 7],
list(range(100)) + [0],
[100],
]
for seq_ids in all_seq_ids:
max_seq_id = max(seq_ids)
iterator = scorer._create_target_seq_id_iterator(seq_ids) # pylint: disable=protected-access
for _ in range(num_target_seq_ids):
assert next(iterator) > max_seq_id
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_get_token_ids_to_score(k: int):
"""Verify correct tokens are selected for scoring.
"""
proposal_token_ids = torch.tensor(
list(range(k)),
dtype=torch.int64,
device='cuda',
)
expected_output: list[list[int]] = [
[],
]
for i in range(proposal_token_ids.shape[0]):
expected_output.append(proposal_token_ids[:i + 1].tolist())
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
actual_output = [
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
]
assert actual_output == expected_output
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_create_single_target_seq_group_metadata(k: int):
"""Verify correct creation of a batch-expanded seq group metadata.
"""
prompt_tokens = [1, 2, 3]
prev_output_tokens = [4, 5, 6]
token_ids = list(range(k))
num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1
final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
token_ids)
block_size = 32
input_seq_group_metadata = create_seq_group_metadata_from_prompts(
[prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
[prev_output_tokens], [num_tokens_processed])[0]
input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
target_seq_id = 100
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
output = scorer._create_single_target_seq_group_metadata( # pylint: disable=protected-access
input_seq_group_metadata,
input_seq_id,
target_seq_id,
token_ids,
input_seq_group_metadata.sampling_params,
)
assert output.request_id == input_seq_group_metadata.request_id
assert output.sampling_params.repetition_penalty == \
input_seq_group_metadata.sampling_params.repetition_penalty
assert output.sampling_params.temperature == \
input_seq_group_metadata.sampling_params.temperature
assert output.sampling_params.top_p == \
input_seq_group_metadata.sampling_params.top_p
assert output.sampling_params.top_k == \
input_seq_group_metadata.sampling_params.top_k
assert len(output.seq_data) == 1
assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
prompt_tokens)
assert output.seq_data[target_seq_id].get_output_token_ids() == tuple(
prev_output_tokens + token_ids)
assert len(output.block_tables) == 1
assert output.block_tables[
target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, mock_worker
@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
acceptance_sampler_method: str):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
running_queue_size=queue_size)
if queue_size > disable_by_batch_size:
with patch.object(worker,
'_run_no_spec',
side_effect=ValueError(exception_secret)), \
pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
assert seq_group_metadata_list[
0].num_speculative_tokens == expected_num_spec_tokens
draft_worker.sampler_output.side_effect = ValueError(exception_secret)
proposer = Top1Proposer(
worker=draft_worker,
device='cpu', # not used
vocab_size=100, # not used
# Must be long enough to avoid being skipped due to length.
max_proposal_len=1024,
)
if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert proposals.proposal_lens.tolist() == [0] * batch_size
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
This test verifies that memory usage remains constant (or never grows) when
we enable / disable speculation via --speculative-disable-by-batch-size.
There are a lot of things we try to keep track of between batches of requests
and if certain tensors are not freed from memory, can result in CUDA ooms.
This is particularly relevant for production situations where speculation might
be enabled during off hours, but disabled once traffic peaks during the workday.
Since traffic will stay high for a long period of time, verifying we do not
increase our memory usage over time is essential to prevent possible CUDA ooms.
"""
import torch
import vllm
from tests.core.utils import create_dummy_prompt
from vllm.sequence import SequenceGroup
ITERATIONS = 100
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
BATCH_SIZE = 5
SPEC_DISABLE_BATCH_SIZE = 2
def add_seq_group_to_engine(engine: vllm.LLMEngine, seq_group: SequenceGroup):
scheduler = engine.scheduler[0]
scheduler.add_seq_group(seq_group)
"""
Since we are using a batch size greater than the disabled batch size,
we can ensure we go through the _no_spec codepath for most of our engine steps.
"""
def test_memory_usage_no_spec():
previous_memory_allocated = None
llm = vllm.LLM(model=MAIN_MODEL,
speculative_config={
"model": SPEC_MODEL,
"num_speculative_tokens": 3,
"disable_by_batch_size": SPEC_DISABLE_BATCH_SIZE,
})
batch_sequences = set()
engine = llm.llm_engine
for i in range(ITERATIONS):
seq, seq_group = create_dummy_prompt(request_id=str(i),
prompt_length=10,
min_tokens=10,
max_tokens=10)
add_seq_group_to_engine(engine, seq_group)
batch_sequences.add(seq)
engine.step()
for seq in list(batch_sequences):
if seq.is_finished():
batch_sequences.remove(seq)
# If we aren't at our batch size yet, continue
if len(batch_sequences) <= BATCH_SIZE:
continue
# Otherwise, loop until at least one request is done
while not any(seq.is_finished() for seq in batch_sequences):
engine.step()
# Remove it from the set
for seq in list(batch_sequences):
if seq.is_finished():
batch_sequences.remove(seq)
# At this point, we are always at the case where we have finished
# processing some number of requests from the batch after running
# several _no_spec executions. The memory should not have
# increased between the previous time this was recorded and the
# current time.
if previous_memory_allocated is None:
previous_memory_allocated = torch.cuda.memory_allocated()
else:
assert previous_memory_allocated == torch.cuda.memory_allocated()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import _get_ranks
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import (get_sampled_token_logprobs,
split_batch_by_proposal_len)
def test_get_all_seq_ids():
"""Verify get_all_seq_ids extracts all seq ids.
"""
expected_seq_ids = list(range(10)) + list(range(100, 110))
seq_group_metadata_list = [
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
seq_data={
seq_id: MagicMock(),
},
sampling_params=MagicMock(),
block_tables={
seq_id: MagicMock(),
},
lora_request=None,
) for seq_id in expected_seq_ids
]
actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
assert actual_seq_ids == expected_seq_ids
@pytest.fixture
def fake_sequence_group_metadata():
seq_ids = list(range(3))
return [
SequenceGroupMetadata(
request_id=str(i),
is_prompt=True,
seq_data={
i: MagicMock(),
},
sampling_params=MagicMock(),
block_tables={
i: MagicMock(),
},
lora_request=None,
) for i in seq_ids
]
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 0]
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
]
expected_indices = [0, 2]
assert filtered_groups == expected_groups
assert indices == expected_indices
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 2]
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
]
expected_indices = [1, 2]
assert filtered_groups == expected_groups
assert indices == expected_indices
def test_empty_inputs():
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
assert filtered_groups == []
assert indices == []
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
proposal_lens = [0, 0, 0]
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
proposal_lens = [1, 1, 1]
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
def mock_spec_decode_sampler(acceptance_sampler_method):
"""
Returns either a RejectionSampler or TypicalAcceptanceSampler
object depending on whether acceptance_sampler_method is
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
"""
if acceptance_sampler_method == "rejection_sampler":
sampler = MagicMock(spec=RejectionSampler)
sampler.token_id_dtype = torch.int64
return sampler
elif acceptance_sampler_method == "typical_acceptance_sampler":
sampler = MagicMock(spec=TypicalAcceptanceSampler)
sampler.token_id_dtype = torch.int64
return sampler
else:
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
def test_get_sampled_token_logprobs():
"""Verify get_sampled_token_logprobs returns consistent rankings
with regular get_ranks when probabilities match exactly.
"""
logprob_tensor = torch.tensor(
[[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size)
sampled_token_tensor = torch.tensor([[1,
0]]) # shape (num_steps, batch_size)
ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
sampled_token_tensor)
ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
sampled_token_tensor.reshape(-1))
assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence as GenericSequence
from itertools import count
from typing import Callable, Optional, TypeVar, Union
from unittest.mock import MagicMock
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceData, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker import Worker
T = TypeVar("T", bound=Worker)
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size
def mock_worker(cls=None,
vocab_size: int = 30_000,
max_model_len: int = 2048,
rank: int = 0,
use_spec: bool = True) -> MagicMock:
if cls is None:
cls = Worker
spec = cls if use_spec else None
worker = MagicMock(spec=spec)
worker.vocab_size = vocab_size
worker.max_model_len = max_model_len
worker.rank = rank
worker.device = 'cuda:0'
return worker
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]):
seed_iter = iter(rand_seeds)
original_execute_model = worker.execute_model
def new_execute_model(*args, **kwargs):
result = original_execute_model(*args, **kwargs)
set_random_seed(next(seed_iter))
return result
return new_execute_model
def zero_kv_cache(cache_engine: list[CacheEngine]):
assert cache_engine[0].gpu_cache
for key_blocks, value_blocks in cache_engine[0].gpu_cache:
key_blocks.zero_()
value_blocks.zero_()
def create_worker(cls: Callable[..., T],
model_name: str,
block_size: int,
num_gpu_blocks: int,
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True,
model_runner_cls: Optional[ModelRunner] = None,
dtype: Optional[str] = "auto") -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
block_size=block_size,
enforce_eager=enforce_eager,
dtype=dtype,
)
engine_config = engine_args.create_engine_config()
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
worker = cls(
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
)
worker.init_device()
worker.load_model()
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
engine_config.cache_config.num_cpu_blocks = 0
worker.initialize_cache(
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
return worker
def create_seq_group_metadata_from_prompts(
prompts: list[list[int]],
num_gpu_blocks: int,
block_size: int,
final_prompt_lens: list[int],
continuations: Optional[list[list[int]]] = None,
seq_ids: Optional[list[int]] = None,
) -> list[SequenceGroupMetadata]:
if continuations is None:
continuations = [[] for _ in prompts]
if seq_ids is None:
seq_ids = list(i for i, _ in enumerate(prompts))
free_gpu_blocks = list(range(num_gpu_blocks))
block_allocations = {
i: [
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(final_len, block_size))
]
for i, final_len in enumerate(final_prompt_lens)
}
seq_grou_metadata_list = []
for i, (prompt_token_ids,
cont_token_ids) in enumerate(zip(prompts, continuations)):
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
data.update_num_computed_tokens(
len(prompt_token_ids) + len(cont_token_ids) - 1)
seq_data = {i: data}
seq_grou_metadata_list.append(
SequenceGroupMetadata(
request_id=str(i),
is_prompt=len(cont_token_ids) == 0,
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations[i][:]},
))
return seq_grou_metadata_list
def create_chunked_seq_group_metadata_from_prompt(
prompt: list[int],
num_gpu_blocks: int,
chunk_size: int,
block_size: int,
seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
if seq_id is None:
seq_id = 0
free_gpu_blocks = list(range(num_gpu_blocks))
block_allocations = [
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(len(prompt), block_size))
]
seq_group_metadata_list = []
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
chunk_ids = prompt[idx:idx + chunk_size]
data = SequenceData.from_seqs(prompt)
data.update_num_computed_tokens(idx)
seq_data = {i: data}
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations},
token_chunk_size=len(chunk_ids)))
return seq_group_metadata_list
def assert_logprobs_dict_allclose(
actual_logprobs: list[dict[int, Logprob]],
expected_logprobs: list[dict[int, Logprob]]) -> None:
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
actual_logprobs, expected_logprobs):
assert set(single_step_actual_logprobs.keys()) == set(
single_step_expected_logprobs.keys())
for token_id in single_step_actual_logprobs:
actual = torch.tensor(
single_step_actual_logprobs[token_id].logprob)
expected = torch.tensor(
single_step_expected_logprobs[token_id].logprob)
torch.testing.assert_close(actual, expected)
def create_sampler_output_list(
token_ids: torch.Tensor,
probs: GenericSequence[Optional[torch.Tensor]],
logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist()
if seq_ids is None:
seq_ids = list(range(batch_size))
return [
SamplerOutput(outputs=[
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
output_token=token_id,
parent_seq_id=seq_ids[seq_index],
logprobs={token_id: Logprob(0)},
)
],
prompt_logprobs=None,
) for seq_index, token_id in enumerate(token_ids_by_step[step])
],
sampled_token_probs=probs[step],
logprobs=logprobs[step],
sampled_token_ids=token_ids[step])
for step in range(num_steps)
]
def create_batch(batch_size,
k,
prompt_len: Union[int, list[int]] = 10,
prev_output_token_len: int = 10,
seq_ids: Optional[list[int]] = None,
num_gpu_blocks: Optional[int] = None,
block_size: Optional[int] = None,
prefill_chunk_size: Optional[int] = None):
if block_size is None:
block_size = 8
if num_gpu_blocks is None:
num_gpu_blocks = 2048 // block_size
iterator = count()
if isinstance(prompt_len, int):
prompt_lens = [prompt_len for _ in range(batch_size)]
else:
prompt_lens = prompt_len
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
if prefill_chunk_size:
# Create a batch of chunked prompts.
if not seq_ids:
seq_ids = list(range(len(prompts)))
seq_group_metadata_list = []
for p, sid in zip(prompts, seq_ids):
seq_group_metadata_list += \
create_chunked_seq_group_metadata_from_prompt(
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
prev_output_tokens = []
else:
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, final_prompt_lens,
prev_output_tokens, seq_ids)
return seq_group_metadata_list, prompts, prev_output_tokens
def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
if prefill_chunk_size > 0:
llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
llm_kwargs["enable_chunked_prefill"] = False
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable
import pytest
from vllm import LLM, EngineArgs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import UniProcExecutor
from vllm.worker.worker_base import WorkerWrapperBase
MODEL_REF = "facebook/opt-125m"
@pytest.fixture()
def model_ref():
return MODEL_REF
@pytest.fixture(autouse=True)
def allow_insecure_serialization(monkeypatch):
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.fixture(autouse=True)
......@@ -11,7 +30,73 @@ def cleanup():
cleanup_dist_env_and_memory(shutdown_ray=True)
@pytest.fixture()
def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path):
def noop(*args, **kwargs):
return None
args = EngineArgs(model=model_ref)
tc = TensorizerConfig(tensorizer_uri=f"{tmp_path}/model.tensors")
monkeypatch.setattr(tensorizer_mod, "serialize_extra_artifacts", noop)
tensorizer_mod.tensorize_vllm_model(args, tc)
yield tmp_path
@pytest.fixture(autouse=True)
def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm")
return config
@pytest.fixture()
def model_path(model_ref, tmp_path):
yield tmp_path / model_ref / "model.tensors"
def assert_from_collective_rpc(engine: LLM, closure: Callable,
closure_kwargs: dict):
res = engine.collective_rpc(method=closure, kwargs=closure_kwargs)
return all(res)
# This is an object pulled from tests/v1/engine/test_engine_core.py
# Modified to strip the `load_model` method from its `_init_executor`
# method. It's purely used as a dummy utility to run methods that test
# Tensorizer functionality
class DummyExecutor(UniProcExecutor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rpc_rank=0)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
local_rank = 0
# set local rank as the device index if specified
device_info = self.vllm_config.device_config.device.__str__().split(
":")
if len(device_info) > 1:
local_rank = int(device_info[1])
rank = 0
is_driver_worker = True
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
@property
def max_concurrent_batches(self) -> int:
return 2
def shutdown(self):
if hasattr(self, 'thread_pool'):
self.thread_pool.shutdown(wait=False)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import gc
import json
import os
import pathlib
import subprocess
import sys
from typing import Any
import pytest
import torch
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
import vllm.model_executor.model_loader.tensorizer
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
TensorSerializer,
is_vllm_tensorized,
open_stream,
tensorize_vllm_model)
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer_loader import (
BLACKLISTED_TENSORIZER_ARGS)
# yapf: enable
from vllm.utils import PlaceholderModule
from ..utils import VLLM_PATH, models_path_prefix
from ..utils import VLLM_PATH, RemoteOpenAIServer, models_path_prefix
from .conftest import DummyExecutor, assert_from_collective_rpc
try:
import tensorizer
from tensorizer import EncryptionParams
except ImportError:
tensorizer = PlaceholderModule("tensorizer") # type: ignore[assignment]
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
class TensorizerCaughtError(Exception):
pass
EXAMPLES_PATH = VLLM_PATH / "examples"
pytest_plugins = "pytest_asyncio",
prompts = [
"Hello, my name is",
"The president of the United States is",
......@@ -42,9 +56,37 @@ prompts = [
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
model_ref = os.path.join(models_path_prefix, "facebook/opt-125m")
tensorize_model_for_testing_script = os.path.join(
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
def patch_init_and_catch_error(self, obj, method_name,
expected_error: type[Exception]):
original = getattr(obj, method_name, None)
if original is None:
raise ValueError("Method '{}' not found.".format(method_name))
def wrapper(*args, **kwargs):
try:
return original(*args, **kwargs)
except expected_error as err:
raise TensorizerCaughtError from err
setattr(obj, method_name, wrapper)
self.load_model()
def assert_specific_tensorizer_error_is_raised(
executor,
obj: Any,
method_name: str,
expected_error: type[Exception],
):
with pytest.raises(TensorizerCaughtError):
executor.collective_rpc(patch_init_and_catch_error,
args=(
obj,
method_name,
expected_error,
))
def is_curl_installed():
......@@ -62,32 +104,12 @@ def write_keyfile(keyfile_path: str):
f.write(encryption_params.key)
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
tensorized_path = f"{model_ref}/fp16/model.tensors"
with vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path,
num_readers=1,
s3_endpoint="object.ord1.coreweave.com",
)) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate(
prompts, sampling_params)
# noqa: E501
assert deserialized_outputs
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
vllm_runner, tmp_path):
model_ref, vllm_runner, tmp_path, model_path):
args = EngineArgs(model=model_ref)
with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
key_path = tmp_path / model_ref / "model.key"
write_keyfile(key_path)
outputs = vllm_model.generate(prompts, sampling_params)
......@@ -113,9 +135,9 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
tmp_path):
tmp_path, model_ref,
model_path):
with hf_runner(model_ref) as hf_model:
model_path = tmp_path / (model_ref + ".tensors")
max_tokens = 50
outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
with open_stream(model_path, "wb+") as stream:
......@@ -125,7 +147,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
with vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path,
tensorizer_uri=str(model_path),
num_readers=1,
)) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate_greedy(
......@@ -134,7 +156,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
assert outputs == deserialized_outputs
def test_load_without_tensorizer_load_format(vllm_runner, capfd):
def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref):
model = None
try:
model = vllm_runner(
......@@ -152,7 +174,8 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd):
torch.cuda.empty_cache()
def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd):
def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd,
model_ref):
model = None
try:
model = vllm_runner(
......@@ -211,7 +234,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
outputs = base_model.generate(prompts, sampling_params)
# load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
model_path = str(tmp_path / model_ref / "model-%02d.tensors")
key_path = tmp_path / (model_ref + ".key")
tensorizer_config = TensorizerConfig(
......@@ -245,13 +268,13 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
@pytest.mark.flaky(reruns=3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner,
tmp_path, model_path):
gc.collect()
torch.cuda.empty_cache()
model_ref = os.path.join(models_path_prefix, "facebook/opt-125m")
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path))
args = EngineArgs(model=model_ref, device="cuda")
args = EngineArgs(model=model_ref)
with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params)
......@@ -266,4 +289,244 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
prompts, sampling_params)
# noqa: E501
assert outputs == deserialized_outputs
\ No newline at end of file
assert outputs == deserialized_outputs
def test_load_with_just_model_tensors(just_serialize_model_tensors, model_ref):
# For backwards compatibility, ensure Tensorizer can be still be loaded
# for inference by passing the model reference name, not a local/S3 dir,
# and the location of the model tensors
model_dir = just_serialize_model_tensors
extra_config = {"tensorizer_uri": f"{model_dir}/model.tensors"}
## Start OpenAI API server
args = [
"--load-format",
"tensorizer",
"--model-loader-extra-config",
json.dumps(extra_config),
]
with RemoteOpenAIServer(model_ref, args):
# This test only concerns itself with being able to load the model
# and successfully initialize the server
pass
def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path):
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
llm = LLM(model=model_ref, )
def serialization_test(self, *args, **kwargs):
# This is performed in the ephemeral worker process, so monkey-patching
# will actually work, and cleanup is guaranteed so don't
# need to reset things
original_dict = serialization_params
to_compare = {}
original = tensorizer.serialization.TensorSerializer.__init__
def tensorizer_serializer_wrapper(self, *args, **kwargs):
nonlocal to_compare
to_compare = kwargs.copy()
return original(self, *args, **kwargs)
tensorizer.serialization.TensorSerializer.__init__ = (
tensorizer_serializer_wrapper)
tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"])
self.save_tensorized_model(tensorizer_config=tensorizer_config, )
return to_compare | original_dict == to_compare
kwargs = {"tensorizer_config": config.to_serializable()}
assert assert_from_collective_rpc(llm, serialization_test, kwargs)
def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(
tmp_path, capfd):
deserialization_kwargs = {
"num_readers": "bar", # illegal value
}
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
args = EngineArgs(model=model_ref)
tensorize_vllm_model(args, config)
loader_tc = TensorizerConfig(
tensorizer_uri=str(model_path),
deserialization_kwargs=deserialization_kwargs,
)
engine_args = EngineArgs(
model="facebook/opt-125m",
load_format="tensorizer",
model_loader_extra_config=loader_tc.to_serializable(),
)
vllm_config = engine_args.create_engine_config()
executor = DummyExecutor(vllm_config)
assert_specific_tensorizer_error_is_raised(
executor,
tensorizer.serialization.TensorDeserializer,
"__init__",
TypeError,
)
def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd):
deserialization_kwargs = {
"num_readers": 1,
}
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
args = EngineArgs(model=model_ref)
tensorize_vllm_model(args, config)
stream_kwargs = {"mode": "foo"}
loader_tc = TensorizerConfig(
tensorizer_uri=str(model_path),
deserialization_kwargs=deserialization_kwargs,
stream_kwargs=stream_kwargs,
)
engine_args = EngineArgs(
model="facebook/opt-125m",
load_format="tensorizer",
model_loader_extra_config=loader_tc.to_serializable(),
)
vllm_config = engine_args.create_engine_config()
executor = DummyExecutor(vllm_config)
assert_specific_tensorizer_error_is_raised(
executor,
vllm.model_executor.model_loader.tensorizer,
"open_stream",
ValueError,
)
@pytest.mark.asyncio
async def test_serialize_and_serve_entrypoints(tmp_path):
model_ref = "facebook/opt-125m"
suffix = "test"
try:
result = subprocess.run([
sys.executable,
f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
model_ref, "serialize", "--serialized-directory",
str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
'{"limit_cpu_concurrency": 4}'
],
check=True,
capture_output=True,
text=True)
except subprocess.CalledProcessError as e:
print("Tensorizing failed.")
print("STDOUT:\n", e.stdout)
print("STDERR:\n", e.stderr)
raise
assert "Successfully serialized" in result.stdout
# Next, try to serve with vllm serve
model_uri = tmp_path / "vllm" / model_ref / suffix / "model.tensors"
model_loader_extra_config = {
"tensorizer_uri": str(model_uri),
"stream_kwargs": {
"force_http": False,
},
"deserialization_kwargs": {
"verify_hash": True,
"num_readers": 8,
}
}
cmd = [
"-m", "vllm.entrypoints.cli.main", "serve", "--host", "localhost",
"--load-format", "tensorizer", model_ref,
"--model-loader-extra-config",
json.dumps(model_loader_extra_config, indent=2)
]
proc = await asyncio.create_subprocess_exec(
sys.executable,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
assert proc.stdout is not None
fut = proc.stdout.readuntil(b"Application startup complete.")
try:
await asyncio.wait_for(fut, 180)
except asyncio.TimeoutError:
pytest.fail("Server did not start successfully")
finally:
proc.terminate()
await proc.communicate()
@pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS)
def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd,
illegal_value):
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
args = EngineArgs(model=model_ref)
tensorize_vllm_model(args, config)
loader_tc = {"tensorizer_uri": str(model_path), illegal_value: "foo"}
try:
vllm_runner(
model_ref,
load_format="tensorizer",
model_loader_extra_config=loader_tc,
)
except RuntimeError:
out, err = capfd.readouterr()
combined_output = out + err
assert (f"ValueError: {illegal_value} is not an allowed "
f"Tensorizer argument.") in combined_output
......@@ -8,7 +8,7 @@ import os
from vllm.compilation.backends import VllmBackend
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
get_field)
get_field, update_config)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
from utils import models_path_prefix
......@@ -48,15 +48,43 @@ def test_get_field():
assert c.default_factory is MISSING
@dataclass
class _TestNestedConfig:
a: _TestConfigFields = field(
default_factory=lambda: _TestConfigFields(a=0))
def test_update_config():
# Simple update
config1 = _TestConfigFields(a=0)
new_config1 = update_config(config1, {"a": 42})
assert new_config1.a == 42
# Nonexistent field
with pytest.raises(AssertionError):
new_config1 = update_config(config1, {"nonexistent": 1})
# Nested update with dataclass
config2 = _TestNestedConfig()
new_inner_config = _TestConfigFields(a=1, c="new_value")
new_config2 = update_config(config2, {"a": new_inner_config})
assert new_config2.a == new_inner_config
# Nested update with dict
config3 = _TestNestedConfig()
new_config3 = update_config(config3, {"a": {"c": "new_value"}})
assert new_config3.a.c == "new_value"
# Nested update with invalid type
with pytest.raises(AssertionError):
new_config3 = update_config(config3, {"a": "new_value"})
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
(os.path.join(models_path_prefix, "distilbert/distilgpt2"), "generate", "generate"),
(os.path.join(models_path_prefix, "intfloat/multilingual-e5-small"), "pooling", "embed"),
(os.path.join(models_path_prefix, "jason9693/Qwen2.5-1.5B-apeach"), "pooling", "classify"),
(os.path.join(models_path_prefix, "cross-encoder/ms-marco-MiniLM-L-6-v2"), "pooling", "classify"),
(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"), "pooling", "reward"),
(os.path.join(models_path_prefix, "openai/whisper-small"), "transcription", "transcription"),
(os.path.join(models_path_prefix,"distilbert/distilgpt2"), "generate", "generate"),
(os.path.join(models_path_prefix,"intfloat/multilingual-e5-small"), "pooling", "embed"),
(os.path.join(models_path_prefix,"jason9693/Qwen2.5-1.5B-apeach"), "pooling", "classify"),
(os.path.join(models_path_prefix,"cross-encoder/ms-marco-MiniLM-L-6-v2"), "pooling", "classify"),
(os.path.join(models_path_prefix,"Qwen/Qwen2.5-Math-RM-72B"), "pooling", "reward"),
(os.path.join(models_path_prefix,"openai/whisper-small"), "generate", "transcription"),
],
)
def test_auto_task(model_id, expected_runner_type, expected_task):
......@@ -71,7 +99,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
if config.runner_type == "pooling":
assert config.task == expected_task
else:
assert expected_task in config.supported_tasks
@pytest.mark.parametrize(
......@@ -100,11 +132,50 @@ def test_score_task(model_id, expected_runner_type, expected_task):
assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "expected_runner_type", "expected_task"),
[
("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto"),
])
def test_draft_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
runner="draft",
tokenizer=model_id,
seed=0,
dtype="float16",
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
("openai/whisper-small", "generate", "transcription"),
],
)
def test_transcription_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
task="transcription",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "bad_task"), [
(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"), "generate"),
(os.path.join(models_path_prefix,"Qwen/Qwen2.5-Math-RM-72B"), "generate"),
(os.path.join(models_path_prefix,"Qwen/Qwen3-0.6B"), "transcription"),
])
def test_incorrect_task(model_id, bad_task):
with pytest.raises(ValueError, match=r"does not support the .* task"):
with pytest.raises(ValueError, match=r"does not support task=.*"):
ModelConfig(
model_id,
task=bad_task,
......
......@@ -29,7 +29,6 @@ def test_sampler_output_initialization(sampler_output, sample_outputs):
assert len(sampler_output) == len(sample_outputs)
assert sampler_output.sampled_token_probs is None
assert sampler_output.sampled_token_ids is None
assert sampler_output.spec_decode_worker_metrics is None
def test_sampler_output_getitem(sampler_output, sample_outputs):
......
......@@ -15,15 +15,18 @@ import os
import pytest
import torch
import zmq
from transformers import AutoTokenizer
from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.transformers_utils.detokenizer_utils import (
convert_ids_list_to_tokens)
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, common_broadcastable_dtype,
deprecate_kwargs, get_open_port, get_tcp_uri,
is_lossless_cast, join_host_port, make_zmq_path,
make_zmq_socket, memory_profiling,
current_stream, deprecate_kwargs, get_open_port,
get_tcp_uri, is_lossless_cast, join_host_port,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_host_port,
split_zmq_path, supports_kw, swap_dict_values)
......@@ -458,6 +461,31 @@ def test_bind_kv_cache():
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
def test_bind_kv_cache_kv_sharing():
from vllm.attention import Attention
ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
'layers.1.self_attn': Attention(32, 128, 0.1),
'layers.2.self_attn': Attention(32, 128, 0.1),
'layers.3.self_attn': Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
]
shared_kv_cache_layers = {
'layers.2.self_attn': 'layers.1.self_attn',
'layers.3.self_attn': 'layers.0.self_attn'
}
bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0]
def test_bind_kv_cache_non_attention():
from vllm.attention import Attention
......@@ -921,3 +949,52 @@ def test_split_host_port():
def test_join_host_port():
assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
assert join_host_port("::1", 5555) == "[::1]:5555"
def test_convert_ids_list_to_tokens():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
token_ids = tokenizer.encode("Hello, world!")
# token_ids = [9707, 11, 1879, 0]
assert tokenizer.convert_ids_to_tokens(token_ids) == [
'Hello', ',', 'Ġworld', '!'
]
tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
assert tokens == ['Hello', ',', ' world', '!']
def test_current_stream_multithread():
import threading
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
main_default_stream = torch.cuda.current_stream()
child_stream = torch.cuda.Stream()
thread_stream_ready = threading.Event()
thread_can_exit = threading.Event()
def child_thread_func():
with torch.cuda.stream(child_stream):
thread_stream_ready.set()
thread_can_exit.wait(timeout=10)
child_thread = threading.Thread(target=child_thread_func)
child_thread.start()
try:
assert thread_stream_ready.wait(
timeout=5), "Child thread failed to enter stream context in time"
main_current_stream = current_stream()
assert main_current_stream != child_stream, "Main thread's current_stream was contaminated by child thread"
assert main_current_stream == main_default_stream, "Main thread's current_stream is not the default stream"
# Notify child thread it can exit
thread_can_exit.set()
finally:
# Ensure child thread exits properly
child_thread.join(timeout=5)
if child_thread.is_alive():
pytest.fail("Child thread failed to exit properly")
......@@ -395,7 +395,7 @@ def test_decode_prompt_logprobs_chunked_prefill(
logprobs=5,
prompt_logprobs=5,
temperature=0.0)
vllm_results = vllm_model.model.generate(
vllm_results = vllm_model.llm.generate(
example_prompts, sampling_params=vllm_sampling_params)
for idx, result in enumerate(vllm_results):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.transformers_utils.tokenizer import get_tokenizer
TOKENIZER_NAMES = ["BAAI/bge-base-en"]
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
@pytest.mark.parametrize("n_tokens", [510])
def test_special_tokens(tokenizer_name: str, n_tokens: int):
tokenizer = get_tokenizer(tokenizer_name, revision="main")
prompts = '[UNK]' * n_tokens
prompt_token_ids = tokenizer.encode(prompts)
assert len(prompt_token_ids) == n_tokens + 2
......@@ -254,7 +254,9 @@ def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser):
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>
meaningwhile, I will also check the weather in Shanghai.
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Shanghai</arg_value>
......@@ -402,4 +404,4 @@ def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
# Incomplete tool calls should not be extracted
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
\ No newline at end of file
assert extracted_tool_calls.content == model_output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import KimiK2ToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
# Use a common model that is likely to be available
MODEL = "moonshotai/Kimi-K2-Instruct"
@pytest.fixture(scope="module")
def kimi_k2_tokenizer():
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
@pytest.fixture
def kimi_k2_tool_parser(kimi_k2_tokenizer):
return KimiK2ToolParser(kimi_k2_tokenizer)
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function
# assert tool call id format
assert actual_tool_call.id.startswith("functions.")
assert actual_tool_call.id.split(':')[-1].isdigit()
assert actual_tool_call.id.split('.')[1].split(
':')[0] == expected_tool_call.function.name
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
model_output = "This is a test"
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@pytest.mark.parametrize(
ids=[
"tool_call_with_content_before",
"multi_tool_call_with_content_before",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
[
ToolCall(id='functions.get_weather:0',
function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"city": "Beijing",
}, ),
),
type='function')
],
"I'll help you check the weather. ",
),
(
"""I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""",
[
ToolCall(id='functions.get_weather:0',
function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"city": "Beijing",
}, ),
),
type='function'),
ToolCall(id='functions.get_weather:1',
function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"city": "Shanghai",
}, ),
),
type='function')
],
"I'll help you check the weather. ",
),
],
)
def test_extract_tool_calls(kimi_k2_tool_parser, model_output,
expected_tool_calls, expected_content):
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser):
"""we'll return every funcall result"""
model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
# Should extract only the valid JSON tool calls
assert len(extracted_tool_calls.tool_calls) == 2
assert extracted_tool_calls.tool_calls[
0].function.name == "invalid_get_weather"
assert extracted_tool_calls.tool_calls[
1].function.name == "valid_get_weather"
def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser):
"""we'll return every funcall result"""
model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
# Should extract only the valid JSON tool calls
assert len(extracted_tool_calls.tool_calls) == 1
assert extracted_tool_calls.tool_calls[
0].function.name == "valid_get_weather"
def test_streaming_basic_functionality(kimi_k2_tool_parser):
"""Test basic streaming functionality."""
# Reset streaming state
kimi_k2_tool_parser.current_tool_name_sent = False
kimi_k2_tool_parser.prev_tool_call_arr = []
kimi_k2_tool_parser.current_tool_id = -1
kimi_k2_tool_parser.streamed_args_for_tool = []
# Test with a simple tool call
current_text = """ check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>"""
# First call should handle the initial setup
result = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you",
current_text=current_text,
delta_text="<|tool_calls_section_end|>",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# The result might be None or contain tool call information
# This depends on the internal state management
if result is not None and hasattr(result,
'tool_calls') and result.tool_calls:
assert len(result.tool_calls) >= 0
def test_streaming_no_tool_calls(kimi_k2_tool_parser):
"""Test streaming when there are no tool calls."""
current_text = "This is just regular text without any tool calls."
result = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="This is just regular text",
current_text=current_text,
delta_text=" without any tool calls.",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# Should return the delta text as content
assert result is not None
assert hasattr(result, 'content')
assert result.content == " without any tool calls."
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser)
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8"
@pytest.fixture(scope="module")
def qwen3_tokenizer():
return get_tokenizer(tokenizer_name=MODEL)
@pytest.fixture
def qwen3_tool_parser(qwen3_tokenizer):
return Qwen3CoderToolParser(qwen3_tokenizer)
@pytest.fixture
def sample_tools():
return [
ChatCompletionToolsParam(type="function",
function={
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
},
"state": {
"type": "string",
"description":
"The state code"
},
"unit": {
"type": "string",
"enum":
["fahrenheit", "celsius"]
}
},
"required": ["city", "state"]
}
}),
ChatCompletionToolsParam(type="function",
function={
"name": "calculate_area",
"description":
"Calculate area of a shape",
"parameters": {
"type": "object",
"properties": {
"shape": {
"type": "string"
},
"dimensions": {
"type": "object"
},
"precision": {
"type": "integer"
}
}
}
})
]
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
# Qwen3 parser doesn't generate IDs during extraction
assert actual_tool_call.type == "function"
assert (
actual_tool_call.function.name == expected_tool_call.function.name)
assert (json.loads(actual_tool_call.function.arguments) == json.loads(
expected_tool_call.function.arguments))
def stream_delta_message_generator(
qwen3_tool_parser: Qwen3CoderToolParser,
qwen3_tokenizer: AnyTokenizer,
model_output: str,
request: Optional[ChatCompletionRequest] = None
) -> Generator[DeltaMessage, None, None]:
all_token_ids = qwen3_tokenizer.encode(model_output,
add_special_tokens=False)
previous_text = ""
previous_tokens = None
prefix_offset = 0
read_offset = 0
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]
(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=qwen3_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
current_text = previous_text + delta_text
delta_message = qwen3_tool_parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
request=request,
)
if delta_message:
yield delta_message
previous_text = current_text
previous_tokens = (previous_tokens +
new_tokens if previous_tokens else new_tokens)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
def test_extract_tool_calls_no_tools(qwen3_tool_parser):
model_output = "This is a test response without any tool calls"
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@pytest.mark.parametrize(
ids=[
"single_tool",
"single_tool_with_content",
"single_tool_multiline_param",
"parallel_tools",
"tool_with_typed_params",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
('''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
], None),
('''Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
], "Sure! Let me check the weather for you."),
('''<tool_call>
<function=calculate_area>
<parameter=shape>
rectangle
</parameter>
<parameter=dimensions>
{"width": 10,
"height": 20}
</parameter>
<parameter=precision>
2
</parameter>
</function>
</tool_call>''', [
ToolCall(function=FunctionCall(name="calculate_area",
arguments=json.dumps({
"shape": "rectangle",
"dimensions": {
"width": 10,
"height": 20
},
"precision": 2
})))
], None),
('''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
<parameter=state>
FL
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
}))),
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit"
})))
], None),
('''Let me calculate that area for you.<tool_call>
<function=calculate_area>
<parameter=shape>
circle
</parameter>
<parameter=dimensions>
{"radius": 15.5}
</parameter>
<parameter=precision>
3
</parameter>
</function>
</tool_call>''', [
ToolCall(function=FunctionCall(name="calculate_area",
arguments=json.dumps({
"shape": "circle",
"dimensions": {
"radius": 15.5
},
"precision": 3
})))
], "Let me calculate that area for you."),
],
)
def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output,
expected_tool_calls, expected_content):
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
model_output, request=request)
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser, sample_tools):
"""Test fallback parsing when XML tags are missing"""
model_output = '''<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>'''
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
model_output, request=request)
assert extracted_tool_calls.tools_called
assert len(extracted_tool_calls.tool_calls) == 1
assert (extracted_tool_calls.tool_calls[0].function.name ==
"get_current_weather")
def test_extract_tool_calls_type_conversion(qwen3_tool_parser):
"""Test parameter type conversion based on tool schema"""
tools = [
ChatCompletionToolsParam(type="function",
function={
"name": "test_types",
"parameters": {
"type": "object",
"properties": {
"int_param": {
"type": "integer"
},
"float_param": {
"type": "float"
},
"bool_param": {
"type": "boolean"
},
"str_param": {
"type": "string"
},
"obj_param": {
"type": "object"
}
}
}
})
]
model_output = '''<tool_call>
<function=test_types>
<parameter=int_param>
42
</parameter>
<parameter=float_param>
3.14
</parameter>
<parameter=bool_param>
true
</parameter>
<parameter=str_param>
hello world
</parameter>
<parameter=obj_param>
{"key": "value"}
</parameter>
</function>
</tool_call>'''
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
model_output, request=request)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["int_param"] == 42
assert args["float_param"] == 3.14
assert args["bool_param"] is True
assert args["str_param"] == "hello world"
assert args["obj_param"] == {"key": "value"}
@pytest.mark.parametrize(
ids=[
"no_tools",
"single_tool",
"single_tool_with_content",
"parallel_tools",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("This is a test without tools", [], "This is a test without tools"),
('''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
], ""),
('''Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
], "Sure! Let me check the weather for you."),
('''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
<parameter=state>
FL
</parameter>
<parameter=unit>
celsius
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
}))),
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "celsius"
})))
], ""),
],
)
def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
sample_tools, model_output,
expected_tool_calls, expected_content):
"""Test incremental streaming behavior"""
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
other_content = ''
tool_states = {} # Track state per tool index
for delta_message in stream_delta_message_generator(
qwen3_tool_parser, qwen3_tokenizer, model_output, request):
# role should never be streamed from tool parser
assert not delta_message.role
if delta_message.content:
other_content += delta_message.content
if delta_message.tool_calls:
for tool_call in delta_message.tool_calls:
idx = tool_call.index
# Initialize state for new tool
if idx not in tool_states:
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"type": None
}
# First chunk should have id, name, and type
if tool_call.id:
tool_states[idx]["id"] = tool_call.id
if tool_call.type:
assert tool_call.type == "function"
tool_states[idx]["type"] = tool_call.type
if tool_call.function:
if tool_call.function.name:
# Should only be set once
assert tool_states[idx]["name"] is None
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
# Accumulate arguments incrementally
tool_states[idx][
"arguments"] += tool_call.function.arguments
# Verify final content
assert other_content == expected_content
# Verify we got all expected tool calls
assert len(tool_states) == len(expected_tool_calls)
# Verify each tool call
for idx, expected_tool in enumerate(expected_tool_calls):
state = tool_states[idx]
assert state["id"] is not None
assert state["type"] == "function"
assert state["name"] == expected_tool.function.name
# Parse accumulated arguments
arguments_str = state["arguments"]
assert arguments_str is not None
actual_args = json.loads(arguments_str)
expected_args = json.loads(expected_tool.function.arguments)
assert actual_args == expected_args
def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser,
qwen3_tokenizer,
sample_tools):
"""Test that streaming is truly incremental"""
model_output = '''I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>'''
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
chunks = []
for delta_message in stream_delta_message_generator(
qwen3_tool_parser, qwen3_tokenizer, model_output, request):
chunks.append(delta_message)
# Should have multiple chunks
assert len(chunks) > 3
# First chunk(s) should be content
assert chunks[0].content is not None
assert chunks[0].tool_calls is None or chunks[0].tool_calls == []
# Should have a chunk with tool header (id, name, type)
header_found = False
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].id:
header_found = True
assert (chunk.tool_calls[0].function.name == "get_current_weather")
assert chunk.tool_calls[0].type == "function"
# Empty initially
assert chunk.tool_calls[0].function.arguments == ""
break
assert header_found
# Should have chunks with incremental arguments
arg_chunks = []
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].function.arguments:
arg_chunks.append(chunk.tool_calls[0].function.arguments)
# Arguments should be streamed incrementally
assert len(arg_chunks) > 1
# Concatenated arguments should form valid JSON
full_args = "".join(arg_chunks)
parsed_args = json.loads(full_args)
assert parsed_args["city"] == "Dallas"
assert parsed_args["state"] == "TX"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment