Commit 0640f227 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
......@@ -6,8 +6,10 @@ from typing import Dict, Tuple
import numpy as np
import pytest
from PIL import Image
from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import async_fetch_image, fetch_image
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
repeat_and_pad_placeholder_tokens)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
......@@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
data_image_async = await async_fetch_image(data_url)
assert _image_equals(data_image_sync, data_image_async)
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
def test_repeat_and_pad_placeholder_tokens(model):
config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index
tokenizer = AutoTokenizer.from_pretrained(model)
test_cases = [
("<image>", 2, "<image><image>", [32000, 32000]),
("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
("<image><image>", [3, 2], "<image><image><image><image><image>",
[32000, 32000, 32000, 32000, 32000]),
("Image:<image>Image:<image>!", [3, 2],
"Image:<image><image><image>Image:<image><image>!",
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
]
for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer=tokenizer,
prompt=prompt,
prompt_token_ids=tokenizer.encode(prompt,
add_special_tokens=False),
placeholder_token_id=image_token_id,
repeat_count=repeat_count,
)
assert new_prompt == expected_prompt
assert new_token_ids == expected_token_ids
......@@ -2,85 +2,115 @@
Run `pytest tests/quantization/test_bitsandbytes.py`.
'''
import gc
import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import SamplingParams
models_to_test = [
models_4bit_to_test = [
('huggyllama/llama-7b', 'quantize model inflight'),
('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
]
models_pre_qaunt_4bit_to_test = [
('lllyasviel/omost-llama-3-8b-4bits',
'read pre-quantized 4-bit NF4 model'),
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
'read pre-quantized 4-bit FP4 model'),
]
models_pre_quant_8bit_to_test = [
('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
]
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:
hf_model_kwargs = {"load_in_4bit": True}
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name, hf_model_kwargs)
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
models_pre_qaunt_4bit_to_test)
def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name)
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_to_test)
def test_load_bnb_model(vllm_runner, model_name, description) -> None:
@pytest.mark.parametrize("model_name, description",
models_pre_quant_8bit_to_test)
def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name)
def log_generated_texts(prompts, outputs, runner_name):
logged_texts = []
for i, (_, generated_text) in enumerate(outputs):
log_entry = {
"prompt": prompts[i],
"runner_name": runner_name,
"generated_text": generated_text,
}
logged_texts.append(log_entry)
return logged_texts
def validate_generated_texts(hf_runner,
vllm_runner,
prompts,
model_name,
hf_model_kwargs=None):
if hf_model_kwargs is None:
hf_model_kwargs = {}
# Run with HF runner
with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
hf_outputs = llm.generate_greedy(prompts, 8)
hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
# Clean up the GPU memory for the next test
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
#Run with vLLM runner
with vllm_runner(model_name,
quantization='bitsandbytes',
load_format='bitsandbytes',
enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
# check the weights in MLP & SelfAttention are quantized to torch.uint8
qweight = model.model.layers[0].mlp.gate_up_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')
qweight = model.model.layers[0].mlp.down_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')
qweight = model.model.layers[0].self_attn.o_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')
qweight = model.model.layers[0].self_attn.qkv_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')
# some weights should not be quantized
weight = model.lm_head.weight
assert weight.dtype != torch.uint8, (
'lm_head weight dtype should not be torch.uint8')
weight = model.model.embed_tokens.weight
assert weight.dtype != torch.uint8, (
'embed_tokens weight dtype should not be torch.uint8')
weight = model.model.layers[0].input_layernorm.weight
assert weight.dtype != torch.uint8, (
'input_layernorm weight dtype should not be torch.uint8')
weight = model.model.layers[0].post_attention_layernorm.weight
assert weight.dtype != torch.uint8, (
'input_layernorm weight dtype should not be torch.uint8')
# check the output of the model is expected
sampling_params = SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=8)
prompts = ['That which does not kill us', 'To be or not to be,']
expected_outputs = [
'That which does not kill us makes us stronger.',
'To be or not to be, that is the question.'
]
outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(outputs) == len(prompts)
for index in range(len(outputs)):
# compare the first line of the output
actual_output = outputs[index][1][0].split('\n', 1)[0]
expected_output = expected_outputs[index].split('\n', 1)[0]
assert len(actual_output) >= len(expected_output), (
f'Actual {actual_output} should be larger than or equal to '
f'expected {expected_output}')
actual_output = actual_output[:len(expected_output)]
assert actual_output == expected_output, (
f'Expected: {expected_output}, but got: {actual_output}')
enforce_eager=True,
gpu_memory_utilization=0.8) as llm:
vllm_outputs = llm.generate_greedy(prompts, 8)
vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
# Clean up the GPU memory for the next test
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
# Compare the generated strings
for hf_log, vllm_log in zip(hf_logs, vllm_logs):
hf_str = hf_log["generated_text"]
vllm_str = vllm_log["generated_text"]
prompt = hf_log["prompt"]
assert hf_str == vllm_str, (f"Model: {model_name}"
f"Mismatch between HF and vLLM outputs:\n"
f"Prompt: {prompt}\n"
f"HF Output: '{hf_str}'\n"
f"vLLM Output: '{vllm_str}'")
......@@ -160,4 +160,4 @@ def test_compressed_tensors_kv_cache(vllm_runner):
model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
output = llm.generate_greedy("Hello world!", max_tokens=20)
assert output
assert output
\ No newline at end of file
......@@ -44,12 +44,16 @@ def mock_causal_accepted_tensor(
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str,
disable_bonus_tokens: bool, seed: int,
device: str):
def test_correct_output_format(which_tokens_accepted: str, seed: int,
disable_bonus_tokens: bool, device: str,
use_flashinfer: bool):
"""Verify the output has correct format given predetermined accepted matrix.
"""
if use_flashinfer and disable_bonus_tokens:
pytest.skip("Flashinfer rejection sampler must enable bonus token.")
set_random_seed(seed)
torch.set_default_device(device)
......@@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str,
dtype=torch.int64)
rejection_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens)
disable_bonus_tokens=disable_bonus_tokens,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
......@@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str,
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str):
device: str, use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler()
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
......@@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int,
device: str):
frac_seeded: float, n_rep: int, device: str,
use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler()
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
......@@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
assert torch.equal(results[j][i], results[0][i])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
batch_size: int, device: str):
"""
Test the flashinfer and nonflashinfer backend generate
the same output metrics.
"""
torch.set_default_device(device)
torch.manual_seed(0)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
num_accepted_tokens = []
num_emitted_tokens = []
num_draft_tokens = []
def get_seeded_seqs():
return {
i: torch.Generator(device=device).manual_seed(i)
for i in range(batch_size)
}
for use_flashinfer in [True, False]:
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
seeded_seqs = get_seeded_seqs()
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, seeded_seqs)
num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
num_draft_tokens.append(rejection_sampler.num_draft_tokens)
assert num_accepted_tokens[0] == num_accepted_tokens[1]
assert num_emitted_tokens[0] == num_emitted_tokens[1]
assert num_draft_tokens[0] == num_draft_tokens[1]
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str, device: str):
which_token_ids: str, device: str,
use_flashinfer: bool):
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
rejection_sampler = RejectionSampler(strict_mode=True)
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer,
strict_mode=True)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
......@@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
seed: int, draft_and_target_probs_equal: bool):
seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
"""Verify rejection sampling approximates target distribution,
despite sampling from a potentially distinct draft distribution.
......@@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution(
"""
torch.set_default_device("cpu")
set_random_seed(seed)
helper = _CorrectnessTestHelper(
vocab_size=10,
rejection_sampler=RejectionSampler(),
rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer),
)
draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
......@@ -398,10 +476,10 @@ class _CorrectnessTestHelper:
draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
num_samples, 1, 1)
# Repeat target probs num_samples * k times.
# Repeat target probs num_samples * (k + 1) times.
# Rejection sampler requires bonus token probs, but they aren't used.
target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
num_samples, self.k, 1)
num_samples, self.k + 1, 1)
# Randomly sample draft token ids from draft probs.
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
......
......@@ -418,6 +418,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
prompt_len = seq_data.get_prompt_len()
seq_lens.append(prompt_len)
assert sgm.sampling_params is not None
if sgm.sampling_params.prompt_logprobs:
# with prompt_logprobs each token in the prompt has a row in
# logits
......@@ -533,6 +534,8 @@ def test_sampler_mixed(seed: int, device: str):
for i, (sequence_output, metadata) in enumerate(
zip(sampler_output, seq_group_metadata_list)):
assert metadata.sampling_params is not None
if metadata.sampling_params.use_beam_search:
continue
......@@ -550,6 +553,8 @@ def test_sampler_mixed(seed: int, device: str):
assert expected_tokens_item is not None
for n, nth_output in enumerate(sequence_output.samples):
assert metadata.sampling_params is not None
if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed
......
......@@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
......@@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
size=(batch_size, k),
dtype=torch.int64)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_probs,
typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
......@@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
......@@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
oob_token_ids[0][0] = rogue_token_id
with pytest.raises(AssertionError):
typical_acceptance_sampler(target_probs,
typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
......@@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
......@@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
......@@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int,
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k, vocab_size)
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
# Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
......@@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int,
# fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same.
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
......@@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs = target_with_bonus_probs[:, :-1]
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
......@@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
......@@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
draft_token_ids = zero_temperature_token_ids
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
......@@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
draft_token_ids = torch.cat(
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
......@@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0. Without any changes to the posterior thresholds
# none of the draft tokens are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs[target_probs == 0] = 0.00001
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
......
......@@ -5,9 +5,10 @@ from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
SamplerOutput, get_all_seq_ids)
get_all_seq_ids)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
......
......@@ -7,8 +7,9 @@ from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
......@@ -229,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(
actual.target_probs,
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual.target_with_bonus_probs,
target_token_probs.reshape(batch_size, k + 1, -1))
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual.draft_probs, proposal_probs)
......
......@@ -4,10 +4,12 @@ import pytest
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import _get_ranks
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import split_batch_by_proposal_len
from vllm.spec_decode.util import (get_sampled_token_logprobs,
split_batch_by_proposal_len)
def test_get_all_seq_ids():
......@@ -55,10 +57,9 @@ def fake_sequence_group_metadata():
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 0]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=True)
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
......@@ -71,10 +72,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 2]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=False)
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
......@@ -86,8 +86,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
def test_empty_inputs():
filtered_groups, indices = split_batch_by_proposal_len(
[], [], select_proposal_len_zero=True)
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
assert filtered_groups == []
assert indices == []
......@@ -95,10 +94,9 @@ def test_empty_inputs():
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
proposal_lens = [0, 0, 0]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=False)
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
......@@ -106,10 +104,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
proposal_lens = [1, 1, 1]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=True)
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
......@@ -131,3 +128,20 @@ def mock_spec_decode_sampler(acceptance_sampler_method):
return sampler
else:
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
def test_get_sampled_token_logprobs():
"""Verify get_sampled_token_logprobs returns consistent rankings
with regular get_ranks when probabilities match exactly.
"""
logprob_tensor = torch.tensor(
[[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size)
sampled_token_tensor = torch.tensor([[1,
0]]) # shape (num_steps, batch_size)
ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
sampled_token_tensor)
ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
sampled_token_tensor.reshape(-1))
assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)
......@@ -8,12 +8,12 @@ from unittest.mock import MagicMock
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput)
SequenceData, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
......
......@@ -2,9 +2,10 @@ from array import array
import pytest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, SamplerOutput,
SequenceData, SequenceOutput)
CompletionSequenceGroupOutput, SequenceData,
SequenceOutput)
from .core.utils import create_dummy_prompt
......
......@@ -132,6 +132,16 @@ def parser():
return parser
@pytest.fixture
def parser_with_config():
parser = FlexibleArgumentParser()
parser.add_argument('serve')
parser.add_argument('--config', type=str)
parser.add_argument('--port', type=int)
parser.add_argument('--tensor-parallel-size', type=int)
return parser
def test_underscore_to_dash(parser):
args = parser.parse_args(['--image_input_type', 'pixel_values'])
assert args.image_input_type == 'pixel_values'
......@@ -176,3 +186,37 @@ def test_missing_required_argument(parser):
parser.add_argument('--required-arg', required=True)
with pytest.raises(SystemExit):
parser.parse_args([])
def test_cli_override_to_config(parser_with_config):
args = parser_with_config.parse_args([
'serve', '--config', './data/test_config.yaml',
'--tensor-parallel-size', '3'
])
assert args.tensor_parallel_size == 3
args = parser_with_config.parse_args([
'serve', '--tensor-parallel-size', '3', '--config',
'./data/test_config.yaml'
])
assert args.tensor_parallel_size == 3
def test_config_args(parser_with_config):
args = parser_with_config.parse_args(
['serve', '--config', './data/test_config.yaml'])
assert args.tensor_parallel_size == 2
def test_config_file(parser_with_config):
with pytest.raises(FileNotFoundError):
parser_with_config.parse_args(['serve', '--config', 'test_config.yml'])
with pytest.raises(ValueError):
parser_with_config.parse_args(
['serve', '--config', './data/test_config.json'])
with pytest.raises(ValueError):
parser_with_config.parse_args([
'serve', '--tensor-parallel-size', '3', '--config', '--batch-size',
'32'
])
import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download
from tests.utils import RemoteOpenAIServer
from .utils import ARGS, CONFIGS, ServerConfig
# for each server config, download the model and return the config
@pytest.fixture(scope="session", params=CONFIGS.keys())
def server_config(request):
config = CONFIGS[request.param]
# download model and tokenizer using transformers
snapshot_download(config["model"])
yield CONFIGS[request.param]
# run this for each server config
@pytest.fixture(scope="session")
def server(request, server_config: ServerConfig):
model = server_config["model"]
args_for_model = server_config["arguments"]
with RemoteOpenAIServer(model, ARGS + args_for_model,
max_wait_seconds=480) as server:
yield server
@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
async with server.get_async_client() as async_client:
yield async_client
from typing import List
import openai
import pytest
from .utils import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL
# test: make sure chat completions without tools provided work even when tools
# are enabled. This makes sure tool call chat templates work, AND that the tool
# parser stream processing doesn't change the output of the model.
@pytest.mark.asyncio
async def test_chat_completion_without_tools(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
temperature=0,
max_tokens=150,
model=model_name,
logprobs=False)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
output_text = chat_completion.choices[0].message.content
# check to make sure we got text
assert output_text is not None
assert len(output_text) > 0
assert stop_reason != "tool_calls"
# check to make sure no tool calls were returned
assert (choice.message.tool_calls is None
or len(choice.message.tool_calls) == 0)
# make the same request, streaming
stream = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
temperature=0,
max_tokens=150,
model=model_name,
logprobs=False,
stream=True,
)
chunks: List[str] = []
finish_reason_count = 0
role_sent: bool = False
# assemble streamed chunks
async for chunk in stream:
delta = chunk.choices[0].delta
# make sure the role is assistant
if delta.role:
assert not role_sent
assert delta.role == 'assistant'
role_sent = True
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == choice.finish_reason
# make sure tool call chunks aren't being streamed
assert not delta.tool_calls or len(delta.tool_calls) == 0
# make sure the role was sent, only 1 finish reason was sent, that chunks
# were in fact sent, and that the chunks match non-streaming
assert role_sent
assert finish_reason_count == 1
assert len(chunks)
assert "".join(chunks) == output_text
# test: conversation with tools enabled and provided that should not invoke
# tools, to make sure we can still get normal chat completion responses
# and that they won't be parsed as tools
@pytest.mark.asyncio
async def test_chat_completion_with_tools(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
temperature=0,
max_tokens=150,
model=model_name,
tools=[WEATHER_TOOL],
logprobs=False)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
output_text = chat_completion.choices[0].message.content
# check to make sure we got text
assert output_text is not None
assert stop_reason != 'tool_calls'
assert len(output_text) > 0
# check to make sure no tool calls were returned
assert (choice.message.tool_calls is None
or len(choice.message.tool_calls) == 0)
# make the same request, streaming
stream = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
temperature=0,
max_tokens=150,
model=model_name,
logprobs=False,
tools=[WEATHER_TOOL],
stream=True,
)
chunks: List[str] = []
finish_reason_count = 0
role_sent: bool = False
# assemble streamed chunks
async for chunk in stream:
delta = chunk.choices[0].delta
# make sure the role is assistant
if delta.role:
assert delta.role == 'assistant'
role_sent = True
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# make sure tool call chunks aren't being streamed
assert not delta.tool_calls or len(delta.tool_calls) == 0
# make sure the role was sent, only 1 finish reason was sent, that chunks
# were in fact sent, and that the chunks match non-streaming
assert role_sent
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == stop_reason
assert chunk.choices[0].finish_reason != 'tool_calls'
assert len(chunks)
assert "".join(chunks) == output_text
import json
from typing import Dict, List, Optional
import openai
import pytest
from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
WEATHER_TOOL)
# test: getting the model to generate parallel tool calls (streaming/not)
# when requested. NOTE that not all models may support this, so some exclusions
# may be added in the future. e.g. llama 3.1 models are not designed to support
# parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
temperature=0,
max_tokens=200,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
# make sure 2 tool calls are present
assert choice.message.role == "assistant"
assert non_streamed_tool_calls is not None
assert len(non_streamed_tool_calls) == 2
for tool_call in non_streamed_tool_calls:
# make sure the tool includes a function and ID
assert tool_call.type == "function"
assert tool_call.function is not None
assert isinstance(tool_call.id, str)
assert len(tool_call.id) > 16
# make sure the weather tool was called correctly
assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
assert isinstance(tool_call.function.arguments, str)
parsed_arguments = json.loads(tool_call.function.arguments)
assert isinstance(parsed_arguments, Dict)
assert isinstance(parsed_arguments.get("city"), str)
assert isinstance(parsed_arguments.get("state"), str)
assert stop_reason == "tool_calls"
# make the same request, streaming
stream = await client.chat.completions.create(
model=model_name,
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
temperature=0,
max_tokens=200,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
role_name: Optional[str] = None
finish_reason_count: int = 0
tool_call_names: List[str] = []
tool_call_args: List[str] = []
tool_call_idx: int = -1
tool_call_id_count: int = 0
async for chunk in stream:
# if there's a finish reason make sure it's tools
if chunk.choices[0].finish_reason:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == 'tool_calls'
# if a role is being streamed make sure it wasn't already set to
# something else
if chunk.choices[0].delta.role:
assert not role_name or role_name == 'assistant'
role_name = 'assistant'
# if a tool call is streamed make sure there's exactly one
# (based on the request parameters
streamed_tool_calls = chunk.choices[0].delta.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
# make sure only one diff is present - correct even for parallel
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
# if a new tool is being called, set up empty arguments
if tool_call.index != tool_call_idx:
tool_call_idx = tool_call.index
tool_call_args.append("")
# if a tool call ID is streamed, make sure one hasn't been already
if tool_call.id:
tool_call_id_count += 1
assert (isinstance(tool_call.id, str)
and (len(tool_call.id) > 16))
# if parts of the function start being streamed
if tool_call.function:
# if the function name is defined, set it. it should be streamed
# IN ENTIRETY, exactly one time.
if tool_call.function.name:
assert isinstance(tool_call.function.name, str)
tool_call_names.append(tool_call.function.name)
if tool_call.function.arguments:
# make sure they're a string and then add them to the list
assert isinstance(tool_call.function.arguments, str)
tool_call_args[
tool_call.index] += tool_call.function.arguments
assert finish_reason_count == 1
assert role_name == 'assistant'
assert (len(non_streamed_tool_calls) == len(tool_call_names) ==
len(tool_call_args))
for i in range(2):
assert non_streamed_tool_calls[i].function.name == tool_call_names[i]
streamed_args = json.loads(tool_call_args[i])
non_streamed_args = json.loads(
non_streamed_tool_calls[i].function.arguments)
assert streamed_args == non_streamed_args
# test: providing parallel tool calls back to the model to get a response
# (streaming/not)
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
temperature=0,
max_tokens=200,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls" # "stop" or "length"
assert choice.message.role == "assistant"
assert choice.message.tool_calls is None \
or len(choice.message.tool_calls) == 0
assert choice.message.content is not None
assert "98" in choice.message.content # Dallas temp in tool response
assert "78" in choice.message.content # Orlando temp in tool response
stream = await client.chat.completions.create(
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
temperature=0,
max_tokens=200,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
chunks: List[str] = []
finish_reason_count = 0
role_sent: bool = False
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert not role_sent
assert delta.role == "assistant"
role_sent = True
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == choice.finish_reason
assert not delta.tool_calls or len(delta.tool_calls) == 0
assert role_sent
assert finish_reason_count == 1
assert len(chunks)
assert "".join(chunks) == choice.message.content
import json
from typing import Dict, List, Optional
import openai
import pytest
from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE,
SEARCH_TOOL, WEATHER_TOOL)
# test: request a chat completion that should return tool calls, so we know they
# are parsable
@pytest.mark.asyncio
async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_TOOLS,
temperature=0,
max_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
tool_calls = chat_completion.choices[0].message.tool_calls
# make sure a tool call is present
assert choice.message.role == 'assistant'
assert tool_calls is not None
assert len(tool_calls) == 1
assert tool_calls[0].type == 'function'
assert tool_calls[0].function is not None
assert isinstance(tool_calls[0].id, str)
assert len(tool_calls[0].id) > 16
# make sure the weather tool was called (classic example) with arguments
assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
assert tool_calls[0].function.arguments is not None
assert isinstance(tool_calls[0].function.arguments, str)
# make sure the arguments parse properly
parsed_arguments = json.loads(tool_calls[0].function.arguments)
assert isinstance(parsed_arguments, Dict)
assert isinstance(parsed_arguments.get("city"), str)
assert isinstance(parsed_arguments.get("state"), str)
assert parsed_arguments.get("city") == "Dallas"
assert parsed_arguments.get("state") == "TX"
assert stop_reason == "tool_calls"
function_name: Optional[str] = None
function_args_str: str = ''
tool_call_id: Optional[str] = None
role_name: Optional[str] = None
finish_reason_count: int = 0
# make the same request, streaming
stream = await client.chat.completions.create(
model=model_name,
messages=MESSAGES_ASKING_FOR_TOOLS,
temperature=0,
max_tokens=100,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
async for chunk in stream:
assert chunk.choices[0].index == 0
if chunk.choices[0].finish_reason:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == 'tool_calls'
# if a role is being streamed make sure it wasn't already set to
# something else
if chunk.choices[0].delta.role:
assert not role_name or role_name == 'assistant'
role_name = 'assistant'
# if a tool call is streamed make sure there's exactly one
# (based on the request parameters
streamed_tool_calls = chunk.choices[0].delta.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
# if a tool call ID is streamed, make sure one hasn't been already
if tool_call.id:
assert not tool_call_id
tool_call_id = tool_call.id
# if parts of the function start being streamed
if tool_call.function:
# if the function name is defined, set it. it should be streamed
# IN ENTIRETY, exactly one time.
if tool_call.function.name:
assert function_name is None
assert isinstance(tool_call.function.name, str)
function_name = tool_call.function.name
if tool_call.function.arguments:
assert isinstance(tool_call.function.arguments, str)
function_args_str += tool_call.function.arguments
assert finish_reason_count == 1
assert role_name == 'assistant'
assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)
# validate the name and arguments
assert function_name == WEATHER_TOOL["function"]["name"]
assert function_name == tool_calls[0].function.name
assert isinstance(function_args_str, str)
# validate arguments
streamed_args = json.loads(function_args_str)
assert isinstance(streamed_args, Dict)
assert isinstance(streamed_args.get("city"), str)
assert isinstance(streamed_args.get("state"), str)
assert streamed_args.get("city") == "Dallas"
assert streamed_args.get("state") == "TX"
# make sure everything matches non-streaming except for ID
assert function_name == tool_calls[0].function.name
assert choice.message.role == role_name
assert choice.message.tool_calls[0].function.name == function_name
# compare streamed with non-streamed args Dict-wise, not string-wise
# because character-to-character comparison might not work e.g. the tool
# call parser adding extra spaces or something like that. we care about the
# dicts matching not byte-wise match
assert parsed_arguments == streamed_args
# test: providing tools and results back to model to get a non-tool response
# (streaming/not)
@pytest.mark.asyncio
async def test_tool_call_with_results(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITH_TOOL_RESPONSE,
temperature=0,
max_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls" # "stop" or "length"
assert choice.message.role == "assistant"
assert choice.message.tool_calls is None \
or len(choice.message.tool_calls) == 0
assert choice.message.content is not None
assert "98" in choice.message.content # the temperature from the response
stream = await client.chat.completions.create(
messages=MESSAGES_WITH_TOOL_RESPONSE,
temperature=0,
max_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
chunks: List[str] = []
finish_reason_count = 0
role_sent: bool = False
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert not role_sent
assert delta.role == "assistant"
role_sent = True
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == choice.finish_reason
assert not delta.tool_calls or len(delta.tool_calls) == 0
assert role_sent
assert finish_reason_count == 1
assert len(chunks)
assert "".join(chunks) == choice.message.content
from typing import Dict, List
from openai.types.chat import (ChatCompletionMessageParam,
ChatCompletionToolParam)
from typing_extensions import TypedDict
from tests.utils import VLLM_PATH
class ServerConfig(TypedDict):
model: str
arguments: List[str]
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"]
CONFIGS: Dict[str, ServerConfig] = {
"hermes": {
"model":
"NousResearch/Hermes-2-Pro-Llama-3-8B",
"arguments": [
"--tool-call-parser", "hermes", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
]
},
"mistral": {
"model":
"mistralai/Mistral-7B-Instruct-v0.3",
"arguments": [
"--tool-call-parser", "mistral", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"),
"--ignore-patterns=\"consolidated.safetensors\""
]
}
}
WEATHER_TOOL: ChatCompletionToolParam = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, "
"e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state "
"that the city is in, e.g. 'CA' which would "
"mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
}
}
}
}
SEARCH_TOOL: ChatCompletionToolParam = {
"type": "function",
"function": {
"name":
"web_search",
"description":
"Search the internet and get a summary of the top "
"10 webpages. Should only be used if you don't know "
"the answer to a user query, and the results are likely"
"to be able to be found with a web search",
"parameters": {
"type": "object",
"properties": {
"search_term": {
"type":
"string",
"description":
"The term to use in the search. This should"
"ideally be keywords to search for, not a"
"natural-language question"
}
},
"required": ["search_term"]
}
}
}
MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{
"role":
"system",
"content":
"You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
}, {
"role":
"user",
"content":
"Hi! How are you?"
}, {
"role":
"assistant",
"content":
"I'm doing great! How can I assist you?"
}, {
"role":
"user",
"content":
"Can you tell me a joke please?"
}]
MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas in Fahrenheit?"
}]
MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas in Fahrenheit?"
}, {
"role":
"assistant",
"tool_calls": [{
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"type": "function",
"function": {
"name":
WEATHER_TOOL["function"]["name"],
"arguments":
'{"city": "Dallas", "state": "TX", '
'"unit": "fahrenheit"}'
}
}]
}, {
"role":
"tool",
"tool_call_id":
"chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"content":
"The weather in Dallas is 98 degrees fahrenheit, with partly"
"cloudy skies and a low chance of rain."
}]
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas and Orlando, Florida in "
"Fahrenheit?"
}]
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas and Orlando, Florida in "
"Fahrenheit?"
}, {
"role":
"assistant",
"tool_calls": [{
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"type": "function",
"function": {
"name":
WEATHER_TOOL["function"]["name"],
"arguments":
'{"city": "Dallas", "state": "TX", '
'"unit": "fahrenheit"}'
}
}, {
"id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
"type": "function",
"function": {
"name":
WEATHER_TOOL["function"]["name"],
"arguments":
'{"city": "Orlando", "state": "Fl", '
'"unit": "fahrenheit"}'
}
}]
}, {
"role":
"tool",
"tool_call_id":
"chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"content":
"The weather in Dallas TX is 98 degrees fahrenheit with mostly "
"cloudy skies and a chance of rain in the evening."
}, {
"role":
"tool",
"tool_call_id":
"chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
"content":
"The weather in Orlando FL is 78 degrees fahrenheit with clear"
"skies."
}]
import glob
import os
import runpy
import tempfile
import depyf
# disable custom dispatcher, let Dynamo takes over
# all the control
os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
cur_dir = os.path.dirname(__file__)
parent_dir = os.path.dirname(cur_dir)
root_dir = os.path.dirname(parent_dir)
example_file = os.path.join(root_dir, "examples",
"offline_inference_tpu.py")
runpy.run_path(example_file)
compiled_code = sorted(
glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
# we should only trigger Dynamo compilation three times:
# one for the profiling phase without kv cache
# one for the prefill phase with symbolic shapes
# one for the decode phase with symbolic shapes
# and later calls should not trigger Dynamo compilation again.
# NOTE: it might still trigger XLA compilation.
# check we have three compiled code
# this is the assumption when we use the custom dispatcher
assert len(compiled_code) == 3
# check all the compilations are as expected
compiled_fn = sorted(
glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
# the first compilation is the profiling phase,
# it should not have any kv cache
with open(compiled_fn[0]) as f:
content = f.read()
assert "kv_caches" not in content
# the second compilation is the prefill phase,
# it should have kv cache and the flash_attention op
with open(compiled_fn[1]) as f:
content = f.read()
assert "kv_caches" in content and "torch.ops.xla.flash_attention" in content
# the third compilation is the decode phase,
# it should have kv cache and the paged_attention op
with open(compiled_fn[2]) as f:
content = f.read()
assert "kv_caches" in content and "torch.ops.xla.paged_attention" in content
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment