Commit 0640f227 authored by zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
...@@ -6,8 +6,10 @@ from typing import Dict, Tuple ...@@ -6,8 +6,10 @@ from typing import Dict, Tuple
import numpy as np import numpy as np
import pytest import pytest
from PIL import Image from PIL import Image
from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import async_fetch_image, fetch_image from vllm.multimodal.utils import (async_fetch_image, fetch_image,
repeat_and_pad_placeholder_tokens)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [ TEST_IMAGE_URLS = [
...@@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], ...@@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
data_image_async = await async_fetch_image(data_url) data_image_async = await async_fetch_image(data_url)
assert _image_equals(data_image_sync, data_image_async) assert _image_equals(data_image_sync, data_image_async)
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
def test_repeat_and_pad_placeholder_tokens(model):
    """Check placeholder expansion against a table of expected results.

    Each case is ``(prompt, repeat_count, expected_prompt,
    expected_token_ids)``. The prompt is tokenized without special tokens,
    passed through ``repeat_and_pad_placeholder_tokens`` together with the
    model's image token index, and both returned values are compared with
    the expectations.
    """
    # The placeholder id comes from the model config rather than being
    # hard-coded in the call below.
    image_token_id = AutoConfig.from_pretrained(model).image_token_index
    tokenizer = AutoTokenizer.from_pretrained(model)

    # NOTE(review): 32000 in the expectations is presumably the <image>
    # token id of this model -- confirm against the model config.
    test_cases = [
        # A single placeholder with an integer repeat count.
        ("<image>", 2, "<image><image>", [32000, 32000]),
        # Two placeholders with an integer repeat count.
        ("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
        # Two placeholders with one repeat count per placeholder.
        ("<image><image>", [3, 2], "<image><image><image><image><image>",
         [32000, 32000, 32000, 32000, 32000]),
        # Placeholders interleaved with regular text.
        ("Image:<image>Image:<image>!", [3, 2],
         "Image:<image><image><image>Image:<image><image>!",
         [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
        # More repeat counts than placeholders in the prompt.
        ("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
    ]

    for case in test_cases:
        prompt, repeat_count, expected_prompt, expected_token_ids = case
        prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)

        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
            tokenizer=tokenizer,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            placeholder_token_id=image_token_id,
            repeat_count=repeat_count,
        )

        assert new_prompt == expected_prompt
        assert new_token_ids == expected_token_ids
...@@ -2,85 +2,115 @@ ...@@ -2,85 +2,115 @@
Run `pytest tests/quantization/test_bitsandbytes.py`. Run `pytest tests/quantization/test_bitsandbytes.py`.
''' '''
import gc
import pytest import pytest
import torch import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import SamplingParams
models_to_test = [ models_4bit_to_test = [
('huggyllama/llama-7b', 'quantize model inflight'), ('huggyllama/llama-7b', 'quantize model inflight'),
('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
] ]
# (model name, description) pairs for pre-quantized 4-bit checkpoints; used to
# parametrize test_load_pre_quant_4bit_bnb_model.
# NOTE(review): "qaunt" looks like a misspelling of "quant"; a rename must
# also update the parametrize decorator that references this name.
models_pre_qaunt_4bit_to_test = [
('lllyasviel/omost-llama-3-8b-4bits',
'read pre-quantized 4-bit NF4 model'),
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
'read pre-quantized 4-bit FP4 model'),
]
# (model name, description) pairs for pre-quantized 8-bit checkpoints; used to
# parametrize test_load_8bit_bnb_model.
models_pre_quant_8bit_to_test = [
('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
]
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                             model_name, description) -> None:
    """Validate generated texts for models quantized to 4-bit at load time.

    The HF runner is asked to load the model with ``load_in_4bit=True``;
    the comparison itself is delegated to ``validate_generated_texts``.
    ``description`` is not read in the body; it only accompanies the model
    name in the parametrization.
    """
    # Only the first example prompt is used.
    first_prompt_only = example_prompts[:1]
    validate_generated_texts(hf_runner,
                             vllm_runner,
                             first_prompt_only,
                             model_name,
                             hf_model_kwargs={"load_in_4bit": True})
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
                         models_pre_qaunt_4bit_to_test)
def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                       model_name, description) -> None:
    """Validate generated texts for pre-quantized 4-bit checkpoints.

    No extra HF model kwargs are passed, so ``validate_generated_texts``
    uses its default for them. ``description`` is not read in the body; it
    only accompanies the model name in the parametrization.
    """
    # Only the first example prompt is used.
    first_prompt_only = example_prompts[:1]
    validate_generated_texts(hf_runner, vllm_runner, first_prompt_only,
                             model_name)
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.') reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_to_test) @pytest.mark.parametrize("model_name, description",
def test_load_bnb_model(vllm_runner, model_name, description) -> None: models_pre_quant_8bit_to_test)
def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name)
def log_generated_texts(prompts, outputs, runner_name):
    """Pair every generated text with its prompt and the producing runner.

    Args:
        prompts: Indexable sequence of prompts; ``prompts[i]`` is paired with
            the ``i``-th entry of ``outputs``.
        outputs: Iterable of 2-item entries. Only the second item (the
            generated text) is kept; the first item is ignored.
        runner_name: Label stored with every entry to identify the runner.

    Returns:
        A list with one dict per entry of ``outputs``, in order, holding the
        keys ``"prompt"``, ``"runner_name"`` and ``"generated_text"``.

    Raises:
        IndexError: If ``outputs`` has more entries than ``prompts``.
    """
    # Index into ``prompts`` instead of zipping so that surplus outputs
    # raise rather than being silently dropped.
    return [
        {
            "prompt": prompts[i],
            "runner_name": runner_name,
            "generated_text": generated_text,
        }
        for i, (_, generated_text) in enumerate(outputs)
    ]
def validate_generated_texts(hf_runner,
vllm_runner,
prompts,
model_name,
hf_model_kwargs=None):
if hf_model_kwargs is None:
hf_model_kwargs = {}
# Run with HF runner
with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
hf_outputs = llm.generate_greedy(prompts, 8)
hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
# Clean up the GPU memory for the next test
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
#Run with vLLM runner
with vllm_runner(model_name, with vllm_runner(model_name,
quantization='bitsandbytes', quantization='bitsandbytes',
load_format='bitsandbytes', load_format='bitsandbytes',
enforce_eager=True) as llm: enforce_eager=True,
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 gpu_memory_utilization=0.8) as llm:
vllm_outputs = llm.generate_greedy(prompts, 8)
# check the weights in MLP & SelfAttention are quantized to torch.uint8 vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
qweight = model.model.layers[0].mlp.gate_up_proj.qweight
assert qweight.dtype == torch.uint8, ( # Clean up the GPU memory for the next test
f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}') torch.cuda.synchronize()
gc.collect()
qweight = model.model.layers[0].mlp.down_proj.qweight torch.cuda.empty_cache()
assert qweight.dtype == torch.uint8, (
f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}') # Compare the generated strings
for hf_log, vllm_log in zip(hf_logs, vllm_logs):
qweight = model.model.layers[0].self_attn.o_proj.qweight hf_str = hf_log["generated_text"]
assert qweight.dtype == torch.uint8, ( vllm_str = vllm_log["generated_text"]
f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}') prompt = hf_log["prompt"]
assert hf_str == vllm_str, (f"Model: {model_name}"
qweight = model.model.layers[0].self_attn.qkv_proj.qweight f"Mismatch between HF and vLLM outputs:\n"
assert qweight.dtype == torch.uint8, ( f"Prompt: {prompt}\n"
f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}') f"HF Output: '{hf_str}'\n"
f"vLLM Output: '{vllm_str}'")
# some weights should not be quantized
weight = model.lm_head.weight
assert weight.dtype != torch.uint8, (
'lm_head weight dtype should not be torch.uint8')
weight = model.model.embed_tokens.weight
assert weight.dtype != torch.uint8, (
'embed_tokens weight dtype should not be torch.uint8')
weight = model.model.layers[0].input_layernorm.weight
assert weight.dtype != torch.uint8, (
'input_layernorm weight dtype should not be torch.uint8')
weight = model.model.layers[0].post_attention_layernorm.weight
assert weight.dtype != torch.uint8, (
'input_layernorm weight dtype should not be torch.uint8')
# check the output of the model is expected
sampling_params = SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=8)
prompts = ['That which does not kill us', 'To be or not to be,']
expected_outputs = [
'That which does not kill us makes us stronger.',
'To be or not to be, that is the question.'
]
outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(outputs) == len(prompts)
for index in range(len(outputs)):
# compare the first line of the output
actual_output = outputs[index][1][0].split('\n', 1)[0]
expected_output = expected_outputs[index].split('\n', 1)[0]
assert len(actual_output) >= len(expected_output), (
f'Actual {actual_output} should be larger than or equal to '
f'expected {expected_output}')
actual_output = actual_output[:len(expected_output)]
assert actual_output == expected_output, (
f'Expected: {expected_output}, but got: {actual_output}')
...@@ -160,4 +160,4 @@ def test_compressed_tensors_kv_cache(vllm_runner): ...@@ -160,4 +160,4 @@ def test_compressed_tensors_kv_cache(vllm_runner):
model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm: with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
output = llm.generate_greedy("Hello world!", max_tokens=20) output = llm.generate_greedy("Hello world!", max_tokens=20)
assert output assert output
\ No newline at end of file
...@@ -44,12 +44,16 @@ def mock_causal_accepted_tensor( ...@@ -44,12 +44,16 @@ def mock_causal_accepted_tensor(
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode() @torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, def test_correct_output_format(which_tokens_accepted: str, seed: int,
disable_bonus_tokens: bool, seed: int, disable_bonus_tokens: bool, device: str,
device: str): use_flashinfer: bool):
"""Verify the output has correct format given predetermined accepted matrix. """Verify the output has correct format given predetermined accepted matrix.
""" """
if use_flashinfer and disable_bonus_tokens:
pytest.skip("Flashinfer rejection sampler must enable bonus token.")
set_random_seed(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
...@@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str, ...@@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str,
dtype=torch.int64) dtype=torch.int64)
rejection_sampler = RejectionSampler( rejection_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens) disable_bonus_tokens=disable_bonus_tokens,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device) rejection_sampler.init_gpu_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted, accepted,
...@@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str, ...@@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str,
@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32))) @pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode() @torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str): device: str, use_flashinfer: bool):
torch.set_default_device(device) torch.set_default_device(device)
rejection_sampler = RejectionSampler() rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device) rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0, bonus_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, 1), size=(batch_size, 1),
...@@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, ...@@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100]) @pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode() @torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int, frac_seeded: float, n_rep: int, device: str,
device: str): use_flashinfer: bool):
torch.set_default_device(device) torch.set_default_device(device)
rejection_sampler = RejectionSampler() rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device) rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0, bonus_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, 1), size=(batch_size, 1),
...@@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, ...@@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
assert torch.equal(results[j][i], results[0][i]) assert torch.equal(results[j][i], results[0][i])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
                                       batch_size: int, device: str):
    """
    Check that the flashinfer and non-flashinfer rejection samplers report
    identical accepted / emitted / draft token counters when given the same
    inputs and the same per-sequence seeds.
    """
    torch.set_default_device(device)
    torch.manual_seed(0)

    # The inputs are generated once and shared by both backends.
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    # The target probabilities carry one position more than the draft (k + 1).
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    # One entry per backend, in the order the backends are run below.
    accepted_per_backend = []
    emitted_per_backend = []
    drafted_per_backend = []

    for flashinfer_enabled in (True, False):
        sampler = RejectionSampler(disable_bonus_tokens=False,
                                   use_flashinfer=flashinfer_enabled)
        sampler.init_gpu_tensors(device=device)

        # Fresh generators for every backend, seeded with the sequence
        # index, so both backends start from the same random state.
        per_seq_generators = {}
        for seq_idx in range(batch_size):
            generator = torch.Generator(device=device)
            per_seq_generators[seq_idx] = generator.manual_seed(seq_idx)

        sampler(target_probs, bonus_token_ids, draft_probs, draft_token_ids,
                per_seq_generators)

        accepted_per_backend.append(sampler.num_accepted_tokens)
        emitted_per_backend.append(sampler.num_emitted_tokens)
        drafted_per_backend.append(sampler.num_draft_tokens)

    assert accepted_per_backend[0] == accepted_per_backend[1]
    assert emitted_per_backend[0] == emitted_per_backend[1]
    assert drafted_per_backend[0] == drafted_per_backend[1]
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids", @pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"]) ["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode() @torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str, def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str, device: str): which_token_ids: str, device: str,
use_flashinfer: bool):
k = 3 k = 3
batch_size = 5 batch_size = 5
vocab_size = 30_000 vocab_size = 30_000
torch.set_default_device(device) torch.set_default_device(device)
rejection_sampler = RejectionSampler(strict_mode=True) rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer,
strict_mode=True)
rejection_sampler.init_gpu_tensors(device=device) rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0, bonus_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, 1), size=(batch_size, 1),
...@@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, ...@@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False]) @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5))) @pytest.mark.parametrize("seed", list(range(5)))
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode() @torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution( def test_rejection_sampling_approximates_target_distribution(
seed: int, draft_and_target_probs_equal: bool): seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
"""Verify rejection sampling approximates target distribution, """Verify rejection sampling approximates target distribution,
despite sampling from a potentially distinct draft distribution. despite sampling from a potentially distinct draft distribution.
...@@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution( ...@@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution(
""" """
torch.set_default_device("cpu") torch.set_default_device("cpu")
set_random_seed(seed) set_random_seed(seed)
helper = _CorrectnessTestHelper( helper = _CorrectnessTestHelper(
vocab_size=10, vocab_size=10,
rejection_sampler=RejectionSampler(), rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer),
) )
draft_probs, target_probs, reference_probs = helper.generate_probs_for_test( draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
...@@ -398,10 +476,10 @@ class _CorrectnessTestHelper: ...@@ -398,10 +476,10 @@ class _CorrectnessTestHelper:
draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat( draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
num_samples, 1, 1) num_samples, 1, 1)
# Repeat target probs num_samples * k times. # Repeat target probs num_samples * (k + 1) times.
# Rejection sampler requires bonus token probs, but they aren't used. # Rejection sampler requires bonus token probs, but they aren't used.
target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat( target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
num_samples, self.k, 1) num_samples, self.k + 1, 1)
# Randomly sample draft token ids from draft probs. # Randomly sample draft token ids from draft probs.
draft_token_ids = torch.multinomial(draft_probs[:, 0, :], draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
......
...@@ -418,6 +418,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -418,6 +418,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
prompt_len = seq_data.get_prompt_len() prompt_len = seq_data.get_prompt_len()
seq_lens.append(prompt_len) seq_lens.append(prompt_len)
assert sgm.sampling_params is not None
if sgm.sampling_params.prompt_logprobs: if sgm.sampling_params.prompt_logprobs:
# with prompt_logprobs each token in the prompt has a row in # with prompt_logprobs each token in the prompt has a row in
# logits # logits
...@@ -533,6 +534,8 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -533,6 +534,8 @@ def test_sampler_mixed(seed: int, device: str):
for i, (sequence_output, metadata) in enumerate( for i, (sequence_output, metadata) in enumerate(
zip(sampler_output, seq_group_metadata_list)): zip(sampler_output, seq_group_metadata_list)):
assert metadata.sampling_params is not None
if metadata.sampling_params.use_beam_search: if metadata.sampling_params.use_beam_search:
continue continue
...@@ -550,6 +553,8 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -550,6 +553,8 @@ def test_sampler_mixed(seed: int, device: str):
assert expected_tokens_item is not None assert expected_tokens_item is not None
for n, nth_output in enumerate(sequence_output.samples): for n, nth_output in enumerate(sequence_output.samples):
assert metadata.sampling_params is not None
if (metadata.sampling_params.temperature == 0 if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None): or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed # Ensure exact matches for greedy or random with seed
......
...@@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, ...@@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
torch.set_default_device(device) torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler() typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(device=device) typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0, bonus_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, 1), size=(batch_size, 1),
...@@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, ...@@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
size=(batch_size, k), size=(batch_size, k),
dtype=torch.int64) dtype=torch.int64)
# Verify that sampling succeeds for all cases. # Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_probs, typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids, bonus_token_ids,
draft_probs=None, draft_probs=None,
draft_token_ids=draft_token_ids) draft_token_ids=draft_token_ids)
...@@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, ...@@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
torch.set_default_device(device) torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device) typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0, bonus_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, 1), size=(batch_size, 1),
...@@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, ...@@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
oob_token_ids[0][0] = rogue_token_id oob_token_ids[0][0] = rogue_token_id
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
typical_acceptance_sampler(target_probs, typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids, bonus_token_ids,
draft_probs=None, draft_probs=None,
draft_token_ids=draft_token_ids) draft_token_ids=draft_token_ids)
...@@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens( ...@@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens(
typical_acceptance_sampler = get_acceptance_sampler( typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(device=device) typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
draft_token_ids = torch.randint(low=0, draft_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, k), size=(batch_size, k),
...@@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens( ...@@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
size=(batch_size, 1), size=(batch_size, 1),
dtype=torch.int64) dtype=torch.int64)
output_token_ids = typical_acceptance_sampler( output_token_ids = typical_acceptance_sampler(
target_probs, target_with_bonus_probs,
bonus_token_ids, bonus_token_ids,
draft_probs=None, draft_probs=None,
draft_token_ids=draft_token_ids) draft_token_ids=draft_token_ids)
...@@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int, ...@@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int,
# Simulate temperature 0 probability distribution for target probabilities # Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has # and create target probabilities such that only 1 token id has
# probability 1.0 # probability 1.0
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist( target_with_bonus_probs, zero_temperature_token_ids = \
batch_size, k, vocab_size) get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
# Populate draft_token_ids such that they exclude the token_ids # Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0 # with probability = 1.0
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
...@@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int, ...@@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int,
# fallback to the greedy sampling for selecting 1 token for each sequence. # fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same. # Verify the same.
output_token_ids = typical_acceptance_sampler( output_token_ids = typical_acceptance_sampler(
target_probs, target_with_bonus_probs,
bonus_token_ids, bonus_token_ids,
draft_probs=None, draft_probs=None,
draft_token_ids=draft_token_ids) draft_token_ids=draft_token_ids)
...@@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, ...@@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
# For sequences 0 and 2 set the distribution to a temperature # For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform # zero distribution. For sequences 1 and 3 set it to a uniform
# distribution. # distribution.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( target_with_bonus_probs, zero_temperature_token_ids = \
batch_size, k, vocab_size)) get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs = target_with_bonus_probs[:, :-1]
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids) zero_temperature_token_ids)
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
...@@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, ...@@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
size=(batch_size, 1), size=(batch_size, 1),
dtype=torch.int64) dtype=torch.int64)
output_token_ids = typical_acceptance_sampler( output_token_ids = typical_acceptance_sampler(
target_probs, target_with_bonus_probs,
bonus_token_ids, bonus_token_ids,
draft_probs=None, draft_probs=None,
draft_token_ids=draft_token_ids) draft_token_ids=draft_token_ids)
...@@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, ...@@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
# Create a temperature zero target probability distribution and ensure # Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability. # all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted. # Verify that all of them are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( target_with_bonus_probs, zero_temperature_token_ids = \
batch_size, k, vocab_size)) get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
draft_token_ids = zero_temperature_token_ids draft_token_ids = zero_temperature_token_ids
bonus_token_ids = torch.randint(low=0, bonus_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, 1), size=(batch_size, 1),
dtype=torch.int64) dtype=torch.int64)
output_token_ids = typical_acceptance_sampler( output_token_ids = typical_acceptance_sampler(
target_probs, target_with_bonus_probs,
bonus_token_ids, bonus_token_ids,
draft_probs=None, draft_probs=None,
draft_token_ids=draft_token_ids) draft_token_ids=draft_token_ids)
...@@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, ...@@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
draft_token_ids = torch.cat( draft_token_ids = torch.cat(
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler( output_token_ids = typical_acceptance_sampler(
target_probs, target_with_bonus_probs,
bonus_token_ids, bonus_token_ids,
draft_probs=None, draft_probs=None,
draft_token_ids=draft_token_ids) draft_token_ids=draft_token_ids)
...@@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, ...@@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# 0.00001. Populate draft_token_ids such that they exclude the token_ids # 0.00001. Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0. Without any changes to the posterior thresholds # with probability = 1.0. Without any changes to the posterior thresholds
# none of the draft tokens are accepted. # none of the draft tokens are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k, vocab_size)) batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs[target_probs == 0] = 0.00001 target_probs[target_probs == 0] = 0.00001
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids) zero_temperature_token_ids)
......
...@@ -5,9 +5,10 @@ from unittest.mock import MagicMock ...@@ -5,9 +5,10 @@ from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob, from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
SamplerOutput, get_all_seq_ids) get_all_seq_ids)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
......
...@@ -7,8 +7,9 @@ from unittest.mock import MagicMock ...@@ -7,8 +7,9 @@ from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector, from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics) SpecDecodeWorkerMetrics)
...@@ -229,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, ...@@ -229,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
assert torch.equal(actual.bonus_token_ids, assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:]) target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal( assert torch.equal(actual.target_with_bonus_probs,
actual.target_probs, target_token_probs.reshape(batch_size, k + 1, -1))
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual.draft_token_ids, proposal_token_ids) assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual.draft_probs, proposal_probs) assert torch.equal(actual.draft_probs, proposal_probs)
......
...@@ -4,10 +4,12 @@ import pytest ...@@ -4,10 +4,12 @@ import pytest
import torch import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import _get_ranks
from vllm.model_executor.layers.typical_acceptance_sampler import ( from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler) TypicalAcceptanceSampler)
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import split_batch_by_proposal_len from vllm.spec_decode.util import (get_sampled_token_logprobs,
split_batch_by_proposal_len)
def test_get_all_seq_ids(): def test_get_all_seq_ids():
...@@ -55,10 +57,9 @@ def fake_sequence_group_metadata(): ...@@ -55,10 +57,9 @@ def fake_sequence_group_metadata():
def test_filter_zero_length_proposals(fake_sequence_group_metadata): def test_filter_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 0] proposal_lens = [0, 1, 0]
filtered_groups, indices = split_batch_by_proposal_len( _, (filtered_groups,
fake_sequence_group_metadata, indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens, proposal_lens)
select_proposal_len_zero=True)
expected_groups = [ expected_groups = [
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2] fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
...@@ -71,10 +72,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata): ...@@ -71,10 +72,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata): def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 2] proposal_lens = [0, 1, 2]
filtered_groups, indices = split_batch_by_proposal_len( (filtered_groups,
fake_sequence_group_metadata, indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens, proposal_lens)
select_proposal_len_zero=False)
expected_groups = [ expected_groups = [
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2] fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
...@@ -86,8 +86,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata): ...@@ -86,8 +86,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
def test_empty_inputs(): def test_empty_inputs():
filtered_groups, indices = split_batch_by_proposal_len( _, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
[], [], select_proposal_len_zero=True)
assert filtered_groups == [] assert filtered_groups == []
assert indices == [] assert indices == []
...@@ -95,10 +94,9 @@ def test_empty_inputs(): ...@@ -95,10 +94,9 @@ def test_empty_inputs():
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata): def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
proposal_lens = [0, 0, 0] proposal_lens = [0, 0, 0]
filtered_groups, indices = split_batch_by_proposal_len( (filtered_groups,
fake_sequence_group_metadata, indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens, proposal_lens)
select_proposal_len_zero=False)
assert filtered_groups == [] assert filtered_groups == []
assert indices == [] assert indices == []
...@@ -106,10 +104,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata): ...@@ -106,10 +104,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
proposal_lens = [1, 1, 1] proposal_lens = [1, 1, 1]
filtered_groups, indices = split_batch_by_proposal_len( _, (filtered_groups,
fake_sequence_group_metadata, indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens, proposal_lens)
select_proposal_len_zero=True)
assert filtered_groups == [] assert filtered_groups == []
assert indices == [] assert indices == []
...@@ -131,3 +128,20 @@ def mock_spec_decode_sampler(acceptance_sampler_method): ...@@ -131,3 +128,20 @@ def mock_spec_decode_sampler(acceptance_sampler_method):
return sampler return sampler
else: else:
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}") raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
def test_get_sampled_token_logprobs():
    """Check that spec-decode ranks agree with the regular sampler's ranks.

    Both vocabulary entries carry the same log-probability, so the test
    exercises how exact ties are ranked by ``get_sampled_token_logprobs``
    versus ``_get_ranks``.
    """
    # shape (num_steps, batch_size, vocab_size)
    tied_logprobs = torch.tensor([[[-.1, -.1], [-.1, -.1]]])
    # shape (num_steps, batch_size)
    sampled_ids = torch.tensor([[1, 0]])

    spec_decode_ranks, _ = get_sampled_token_logprobs(tied_logprobs,
                                                      sampled_ids)

    # The regular sampler works on a flattened (batch, vocab) layout.
    flat_logprobs = tied_logprobs.reshape((2, -1))
    flat_sampled_ids = sampled_ids.reshape(-1)
    regular_ranks = _get_ranks(flat_logprobs, flat_sampled_ids)

    assert torch.equal(spec_decode_ranks.reshape(-1), regular_ranks)
...@@ -8,12 +8,12 @@ from unittest.mock import MagicMock ...@@ -8,12 +8,12 @@ from unittest.mock import MagicMock
import torch import torch
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob, CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceData, SequenceGroupMetadata, SequenceOutput)
SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
......
...@@ -2,9 +2,10 @@ from array import array ...@@ -2,9 +2,10 @@ from array import array
import pytest import pytest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, SamplerOutput, CompletionSequenceGroupOutput, SequenceData,
SequenceData, SequenceOutput) SequenceOutput)
from .core.utils import create_dummy_prompt from .core.utils import create_dummy_prompt
......
...@@ -132,6 +132,16 @@ def parser(): ...@@ -132,6 +132,16 @@ def parser():
return parser return parser
@pytest.fixture
def parser_with_config():
    """Build a parser with a ``serve`` positional and ``--config`` support.

    Registers ``--config`` (str) plus two integer options, ``--port`` and
    ``--tensor-parallel-size``, which the config-file tests read back.
    """
    config_parser = FlexibleArgumentParser()
    config_parser.add_argument('serve')
    option_specs = (
        ('--config', str),
        ('--port', int),
        ('--tensor-parallel-size', int),
    )
    for flag, value_type in option_specs:
        config_parser.add_argument(flag, type=value_type)
    return config_parser
def test_underscore_to_dash(parser): def test_underscore_to_dash(parser):
args = parser.parse_args(['--image_input_type', 'pixel_values']) args = parser.parse_args(['--image_input_type', 'pixel_values'])
assert args.image_input_type == 'pixel_values' assert args.image_input_type == 'pixel_values'
...@@ -176,3 +186,37 @@ def test_missing_required_argument(parser): ...@@ -176,3 +186,37 @@ def test_missing_required_argument(parser):
parser.add_argument('--required-arg', required=True) parser.add_argument('--required-arg', required=True)
with pytest.raises(SystemExit): with pytest.raises(SystemExit):
parser.parse_args([]) parser.parse_args([])
def test_cli_override_to_config(parser_with_config):
    """An explicit CLI value must win over the config file value.

    The check is repeated with ``--config`` placed before and after the
    explicit ``--tensor-parallel-size`` flag.
    """
    config_first = [
        'serve', '--config', './data/test_config.yaml',
        '--tensor-parallel-size', '3'
    ]
    config_last = [
        'serve', '--tensor-parallel-size', '3', '--config',
        './data/test_config.yaml'
    ]
    for argv in (config_first, config_last):
        parsed = parser_with_config.parse_args(argv)
        assert parsed.tensor_parallel_size == 3
def test_config_args(parser_with_config):
    """With no CLI override, ``tensor_parallel_size`` must parse as 2."""
    argv = ['serve', '--config', './data/test_config.yaml']
    parsed = parser_with_config.parse_args(argv)
    assert parsed.tensor_parallel_size == 2
def test_config_file(parser_with_config):
    """Invalid ``--config`` usages must raise the expected exception."""
    failing_cases = (
        # config path given with a .yml suffix
        (FileNotFoundError, ['serve', '--config', 'test_config.yml']),
        # config path given with a .json suffix
        (ValueError, ['serve', '--config', './data/test_config.json']),
        # --config immediately followed by another flag instead of a path
        (ValueError, [
            'serve', '--tensor-parallel-size', '3', '--config',
            '--batch-size', '32'
        ]),
    )
    for expected_error, argv in failing_cases:
        with pytest.raises(expected_error):
            parser_with_config.parse_args(argv)
import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download
from tests.utils import RemoteOpenAIServer
from .utils import ARGS, CONFIGS, ServerConfig
# for each server config, download the model and return the config
@pytest.fixture(scope="session", params=CONFIGS.keys())
def server_config(request):
    """Yield each entry of ``CONFIGS`` in turn, after fetching its model.

    Parametrized over the keys of ``CONFIGS``, so every dependent test runs
    once per server configuration.
    """
    selected = CONFIGS[request.param]

    # fetch the model repository via huggingface_hub before the server starts
    snapshot_download(selected["model"])

    yield selected
# run this for each server config
@pytest.fixture(scope="session")
def server(request, server_config: ServerConfig):
    """Start a ``RemoteOpenAIServer`` for the current server configuration.

    The universal ``ARGS`` are passed first, followed by the model-specific
    arguments. The server is used as a context manager around the ``yield``.
    """
    launch_args = ARGS + server_config["arguments"]
    with RemoteOpenAIServer(server_config["model"],
                            launch_args,
                            max_wait_seconds=480) as remote_server:
        yield remote_server
@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
    """Yield the async client obtained from the running test server."""
    async with server.get_async_client() as openai_client:
        yield openai_client
from typing import List
import openai
import pytest
from .utils import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL
# test: make sure chat completions without tools provided work even when tools
# are enabled. This makes sure tool call chat templates work, AND that the tool
# parser stream processing doesn't change the output of the model.
@pytest.mark.asyncio
async def test_chat_completion_without_tools(client: openai.AsyncOpenAI):
    """Chat without ``tools`` in the request must produce plain text.

    Sends the same request non-streaming and streaming, and checks that
    neither returns tool calls and that the streamed text equals the
    non-streamed text.
    """
    available_models = await client.models.list()
    model_name: str = available_models.data[0].id

    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_WITHOUT_TOOLS,
        temperature=0,
        max_tokens=150,
        model=model_name,
        logprobs=False)
    choice = chat_completion.choices[0]
    output_text = choice.message.content

    # the reply must be non-empty text that did not finish on a tool call
    assert output_text is not None
    assert len(output_text) > 0
    assert choice.finish_reason != "tool_calls"

    # no tool calls may be attached to the message (None or empty list)
    assert not choice.message.tool_calls

    # make the same request, streaming
    stream = await client.chat.completions.create(
        messages=MESSAGES_WITHOUT_TOOLS,
        temperature=0,
        max_tokens=150,
        model=model_name,
        logprobs=False,
        stream=True,
    )
    chunks: List[str] = []
    finish_reason_count = 0
    role_sent: bool = False

    # assemble streamed chunks
    async for chunk in stream:
        streamed_choice = chunk.choices[0]
        delta = streamed_choice.delta

        # the role may be sent only once and must be assistant
        if delta.role:
            assert not role_sent
            assert delta.role == 'assistant'
            role_sent = True

        if delta.content:
            chunks.append(delta.content)

        # any finish reason must agree with the non-streaming response
        if streamed_choice.finish_reason is not None:
            finish_reason_count += 1
            assert streamed_choice.finish_reason == choice.finish_reason

        # tool call chunks must never be streamed
        assert not delta.tool_calls

    # role sent, exactly one finish reason, some content chunks, and the
    # concatenated stream equals the non-streaming text
    assert role_sent
    assert finish_reason_count == 1
    assert chunks
    assert "".join(chunks) == output_text
# test: conversation with tools enabled and provided that should not invoke
# tools, to make sure we can still get normal chat completion responses
# and that they won't be parsed as tools
@pytest.mark.asyncio
async def test_chat_completion_with_tools(client: openai.AsyncOpenAI):
    """A chat that offers a tool but does not need it must return text.

    The request offers ``WEATHER_TOOL`` with the ``MESSAGES_WITHOUT_TOOLS``
    conversation. Both the non-streaming and the streaming response must be
    plain text with no tool calls, and the two texts must match.
    """
    available_models = await client.models.list()
    model_name: str = available_models.data[0].id

    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_WITHOUT_TOOLS,
        temperature=0,
        max_tokens=150,
        model=model_name,
        tools=[WEATHER_TOOL],
        logprobs=False)
    choice = chat_completion.choices[0]
    stop_reason = choice.finish_reason
    output_text = choice.message.content

    # the reply must be non-empty text that did not finish on a tool call
    assert output_text is not None
    assert stop_reason != 'tool_calls'
    assert len(output_text) > 0

    # no tool calls may be attached to the message (None or empty list)
    assert not choice.message.tool_calls

    # make the same request, streaming
    stream = await client.chat.completions.create(
        messages=MESSAGES_WITHOUT_TOOLS,
        temperature=0,
        max_tokens=150,
        model=model_name,
        logprobs=False,
        tools=[WEATHER_TOOL],
        stream=True,
    )

    chunks: List[str] = []
    finish_reason_count = 0
    role_sent: bool = False

    # assemble streamed chunks
    async for chunk in stream:
        streamed_choice = chunk.choices[0]
        delta = streamed_choice.delta

        # whenever a role is sent it must be assistant
        if delta.role:
            assert delta.role == 'assistant'
            role_sent = True

        if delta.content:
            chunks.append(delta.content)

        if streamed_choice.finish_reason is not None:
            finish_reason_count += 1

        # tool call chunks must never be streamed
        assert not delta.tool_calls

    # role sent and exactly one finish reason seen
    assert role_sent
    assert finish_reason_count == 1

    # `chunk` is the final chunk of the stream here; its finish reason must
    # match the non-streaming one and must not be a tool call
    last_finish_reason = chunk.choices[0].finish_reason
    assert last_finish_reason == stop_reason
    assert last_finish_reason != 'tool_calls'

    # content was streamed and equals the non-streaming text
    assert chunks
    assert "".join(chunks) == output_text
import json
from typing import Dict, List, Optional
import openai
import pytest
from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
WEATHER_TOOL)
# test: getting the model to generate parallel tool calls (streaming/not)
# when requested. NOTE that not all models may support this, so some exclusions
# may be added in the future. e.g. llama 3.1 models are not designed to support
# parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
    """The model must emit two weather tool calls, streaming or not.

    First validates the two tool calls of a non-streaming response. Then
    repeats the request with ``stream=True``, reassembles the tool calls
    from the deltas and compares names and parsed arguments with the
    non-streaming result.
    """
    available_models = await client.models.list()
    model_name: str = available_models.data[0].id
    weather_tool_name = WEATHER_TOOL["function"]["name"]

    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)

    choice = chat_completion.choices[0]
    non_streamed_tool_calls = choice.message.tool_calls

    # make sure 2 tool calls are present
    assert choice.message.role == "assistant"
    assert non_streamed_tool_calls is not None
    assert len(non_streamed_tool_calls) == 2

    for tool_call in non_streamed_tool_calls:
        # every tool call needs a function and a sufficiently long string ID
        assert tool_call.type == "function"
        assert tool_call.function is not None
        assert isinstance(tool_call.id, str)
        assert len(tool_call.id) > 16

        # the weather tool must be the one called, with string arguments
        assert tool_call.function.name == weather_tool_name
        assert isinstance(tool_call.function.arguments, str)

        # the arguments must be JSON with string city and state
        parsed_arguments = json.loads(tool_call.function.arguments)
        assert isinstance(parsed_arguments, Dict)
        assert isinstance(parsed_arguments.get("city"), str)
        assert isinstance(parsed_arguments.get("state"), str)

    assert choice.finish_reason == "tool_calls"

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_tokens=200,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    role_name: Optional[str] = None
    finish_reason_count: int = 0

    tool_call_names: List[str] = []
    tool_call_args: List[str] = []
    tool_call_idx: int = -1
    tool_call_id_count: int = 0

    async for chunk in stream:
        streamed_choice = chunk.choices[0]
        delta = streamed_choice.delta

        # a finish reason, when present, must be tool_calls
        if streamed_choice.finish_reason:
            finish_reason_count += 1
            assert streamed_choice.finish_reason == 'tool_calls'

        # a streamed role must not conflict with one already recorded
        if delta.role:
            assert not role_name or role_name == 'assistant'
            role_name = 'assistant'

        # nothing more to check on chunks without a tool call delta
        streamed_tool_calls = delta.tool_calls
        if not streamed_tool_calls:
            continue

        # make sure only one diff is present - correct even for parallel
        assert len(streamed_tool_calls) == 1
        tool_call = streamed_tool_calls[0]

        # a new index marks a new tool call: open an empty argument buffer
        if tool_call.index != tool_call_idx:
            tool_call_idx = tool_call.index
            tool_call_args.append("")

        # count streamed IDs and validate their form
        if tool_call.id:
            tool_call_id_count += 1
            assert (isinstance(tool_call.id, str)
                    and (len(tool_call.id) > 16))

        # collect the streamed pieces of the function
        function_delta = tool_call.function
        if function_delta:
            # the name should be streamed IN ENTIRETY, exactly one time
            if function_delta.name:
                assert isinstance(function_delta.name, str)
                tool_call_names.append(function_delta.name)

            # argument fragments are appended to this call's buffer
            if function_delta.arguments:
                assert isinstance(function_delta.arguments, str)
                tool_call_args[tool_call.index] += function_delta.arguments

    assert finish_reason_count == 1
    assert role_name == 'assistant'

    assert (len(non_streamed_tool_calls) == len(tool_call_names) ==
            len(tool_call_args))

    # compare names and parsed arguments pairwise; all three sequences were
    # just asserted to have the same length (two)
    for reference_call, streamed_name, streamed_raw_args in zip(
            non_streamed_tool_calls, tool_call_names, tool_call_args):
        assert reference_call.function.name == streamed_name

        streamed_args = json.loads(streamed_raw_args)
        non_streamed_args = json.loads(reference_call.function.arguments)
        assert streamed_args == non_streamed_args
# test: providing parallel tool calls back to the model to get a response
# (streaming/not)
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI):
    """Given both tool results, the model must answer in text.

    The answer must mention both temperatures from the tool responses, and
    the streamed answer must equal the non-streamed one.
    """
    available_models = await client.models.list()
    model_name: str = available_models.data[0].id

    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)

    choice = chat_completion.choices[0]
    answer = choice.message

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert answer.role == "assistant"
    assert not answer.tool_calls
    assert answer.content is not None
    assert "98" in answer.content  # Dallas temp in tool response
    assert "78" in answer.content  # Orlando temp in tool response

    # make the same request, streaming
    stream = await client.chat.completions.create(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    chunks: List[str] = []
    finish_reason_count = 0
    role_sent: bool = False

    async for chunk in stream:
        streamed_choice = chunk.choices[0]
        delta = streamed_choice.delta

        # the role may be sent only once and must be assistant
        if delta.role:
            assert not role_sent
            assert delta.role == "assistant"
            role_sent = True

        if delta.content:
            chunks.append(delta.content)

        # any finish reason must agree with the non-streaming response
        if streamed_choice.finish_reason is not None:
            finish_reason_count += 1
            assert streamed_choice.finish_reason == choice.finish_reason

        # tool call chunks must never be streamed
        assert not delta.tool_calls

    assert role_sent
    assert finish_reason_count == 1
    assert chunks
    assert "".join(chunks) == choice.message.content
import json
from typing import Dict, List, Optional
import openai
import pytest
from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE,
SEARCH_TOOL, WEATHER_TOOL)
# test: request a chat completion that should return tool calls, so we know they
# are parsable
@pytest.mark.asyncio
async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
    """A weather question must produce one parsable weather tool call.

    Validates the single tool call of a non-streaming response, then
    repeats the request with ``stream=True``, reassembles the tool call
    from the deltas and checks it matches the non-streaming one (apart
    from the ID).
    """
    available_models = await client.models.list()
    model_name: str = available_models.data[0].id
    weather_tool_name = WEATHER_TOOL["function"]["name"]

    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)

    choice = chat_completion.choices[0]
    tool_calls = choice.message.tool_calls

    # make sure exactly one function tool call with a string ID is present
    assert choice.message.role == 'assistant'
    assert tool_calls is not None
    assert len(tool_calls) == 1
    first_call = tool_calls[0]
    assert first_call.type == 'function'
    assert first_call.function is not None
    assert isinstance(first_call.id, str)
    assert len(first_call.id) > 16

    # make sure the weather tool was called (classic example) with arguments
    assert first_call.function.name == weather_tool_name
    assert first_call.function.arguments is not None
    assert isinstance(first_call.function.arguments, str)

    # make sure the arguments parse properly
    parsed_arguments = json.loads(first_call.function.arguments)
    assert isinstance(parsed_arguments, Dict)
    assert isinstance(parsed_arguments.get("city"), str)
    assert isinstance(parsed_arguments.get("state"), str)
    assert parsed_arguments.get("city") == "Dallas"
    assert parsed_arguments.get("state") == "TX"

    assert choice.finish_reason == "tool_calls"

    # accumulators for the streamed response
    function_name: Optional[str] = None
    function_args_str: str = ''
    tool_call_id: Optional[str] = None
    role_name: Optional[str] = None
    finish_reason_count: int = 0

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_tokens=100,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    async for chunk in stream:
        streamed_choice = chunk.choices[0]
        assert streamed_choice.index == 0
        delta = streamed_choice.delta

        # a finish reason, when present, must be tool_calls
        if streamed_choice.finish_reason:
            finish_reason_count += 1
            assert streamed_choice.finish_reason == 'tool_calls'

        # a streamed role must not conflict with one already recorded
        if delta.role:
            assert not role_name or role_name == 'assistant'
            role_name = 'assistant'

        # nothing more to check on chunks without a tool call delta
        streamed_tool_calls = delta.tool_calls
        if not streamed_tool_calls:
            continue

        # exactly one tool call delta per chunk is expected
        assert len(streamed_tool_calls) == 1
        tool_call = streamed_tool_calls[0]

        # the tool call ID may be streamed only once
        if tool_call.id:
            assert not tool_call_id
            tool_call_id = tool_call.id

        # collect the streamed pieces of the function
        function_delta = tool_call.function
        if function_delta:
            # the name should be streamed IN ENTIRETY, exactly one time
            if function_delta.name:
                assert function_name is None
                assert isinstance(function_delta.name, str)
                function_name = function_delta.name
            if function_delta.arguments:
                assert isinstance(function_delta.arguments, str)
                function_args_str += function_delta.arguments

    assert finish_reason_count == 1
    assert role_name == 'assistant'
    assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)

    # validate the name and arguments
    assert function_name == weather_tool_name
    assert function_name == first_call.function.name
    assert isinstance(function_args_str, str)

    # validate arguments
    streamed_args = json.loads(function_args_str)
    assert isinstance(streamed_args, Dict)
    assert isinstance(streamed_args.get("city"), str)
    assert isinstance(streamed_args.get("state"), str)
    assert streamed_args.get("city") == "Dallas"
    assert streamed_args.get("state") == "TX"

    # make sure everything matches non-streaming except for ID
    assert choice.message.role == role_name

    # compare streamed with non-streamed args Dict-wise, not string-wise
    # because character-to-character comparison might not work e.g. the tool
    # call parser adding extra spaces or something like that. we care about
    # the dicts matching not byte-wise match
    assert parsed_arguments == streamed_args
# test: providing tools and results back to model to get a non-tool response
# (streaming/not)
@pytest.mark.asyncio
async def test_tool_call_with_results(client: openai.AsyncOpenAI):
    """Given a tool result, the model must answer in text.

    The answer must mention the temperature from the tool response, and the
    streamed answer must equal the non-streamed one.
    """
    available_models = await client.models.list()
    model_name: str = available_models.data[0].id

    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_WITH_TOOL_RESPONSE,
        temperature=0,
        max_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)

    choice = chat_completion.choices[0]
    answer = choice.message

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert answer.role == "assistant"
    assert not answer.tool_calls
    assert answer.content is not None
    assert "98" in answer.content  # the temperature from the response

    # make the same request, streaming
    stream = await client.chat.completions.create(
        messages=MESSAGES_WITH_TOOL_RESPONSE,
        temperature=0,
        max_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    chunks: List[str] = []
    finish_reason_count = 0
    role_sent: bool = False

    async for chunk in stream:
        streamed_choice = chunk.choices[0]
        delta = streamed_choice.delta

        # the role may be sent only once and must be assistant
        if delta.role:
            assert not role_sent
            assert delta.role == "assistant"
            role_sent = True

        if delta.content:
            chunks.append(delta.content)

        # any finish reason must agree with the non-streaming response
        if streamed_choice.finish_reason is not None:
            finish_reason_count += 1
            assert streamed_choice.finish_reason == choice.finish_reason

        # tool call chunks must never be streamed
        assert not delta.tool_calls

    assert role_sent
    assert finish_reason_count == 1
    assert chunks
    assert "".join(chunks) == choice.message.content
from typing import Dict, List
from openai.types.chat import (ChatCompletionMessageParam,
ChatCompletionToolParam)
from typing_extensions import TypedDict
from tests.utils import VLLM_PATH
class ServerConfig(TypedDict):
    """Typed shape of one entry in ``CONFIGS``."""

    # model identifier string for this configuration
    model: str
    # additional command-line arguments specific to this configuration
    arguments: List[str]
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"]

# One entry per server configuration under test, keyed by a short name.
# Each entry names a model and the extra arguments (tool-call parser and
# chat template path under VLLM_PATH) to combine with ARGS.
CONFIGS: Dict[str, ServerConfig] = {
    "hermes": {
        "model":
        "NousResearch/Hermes-2-Pro-Llama-3-8B",
        "arguments": [
            "--tool-call-parser", "hermes", "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
        ]
    },
    "mistral": {
        "model":
        "mistralai/Mistral-7B-Instruct-v0.3",
        "arguments": [
            "--tool-call-parser", "mistral", "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"),
            # NOTE(review): the escaped double quotes are part of the
            # argument value itself; confirm that whatever consumes
            # --ignore-patterns expects (or strips) them.
            "--ignore-patterns=\"consolidated.safetensors\""
        ]
    }
}
# Function tool definition for "get_current_weather". It declares three
# properties (city, state, unit) and, unlike SEARCH_TOOL, has no "required"
# list.
WEATHER_TOOL: ChatCompletionToolParam = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type":
                    "string",
                    "description":
                    "The city to find the weather for, "
                    "e.g. 'San Francisco'"
                },
                "state": {
                    "type":
                    "string",
                    "description":
                    "the two-letter abbreviation for the state "
                    "that the city is in, e.g. 'CA' which would "
                    "mean 'California'"
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"]
                }
            }
        }
    }
}
# Function tool definition for "web_search", with a single required
# "search_term" string parameter.
#
# The descriptions are built from adjacent string literals, which Python
# joins with no separator, so every fragment that continues a sentence must
# end with an explicit trailing space.
SEARCH_TOOL: ChatCompletionToolParam = {
    "type": "function",
    "function": {
        "name":
        "web_search",
        "description":
        "Search the internet and get a summary of the top "
        "10 webpages. Should only be used if you don't know "
        "the answer to a user query, and the results are likely "
        "to be able to be found with a web search",
        "parameters": {
            "type": "object",
            "properties": {
                "search_term": {
                    "type":
                    "string",
                    "description":
                    "The term to use in the search. This should "
                    "ideally be keywords to search for, not a "
                    "natural-language question"
                }
            },
            "required": ["search_term"]
        }
    }
}
# Conversation containing no tool calls or tool results: a system prompt,
# one user/assistant exchange, and a final user request for a joke.
MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{
    "role":
    "system",
    "content":
    "You are a helpful assistant with access to tools. If a tool"
    " that you have would be helpful to answer a user query, "
    "call the tool. Otherwise, answer the user's query directly "
    "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
    "to the user's question - just respond to it normally."
}, {
    "role":
    "user",
    "content":
    "Hi! How are you?"
}, {
    "role":
    "assistant",
    "content":
    "I'm doing great! How can I assist you?"
}, {
    "role":
    "user",
    "content":
    "Can you tell me a joke please?"
}]
# Single user turn asking for the weather in one city.
MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{
    "role":
    "user",
    "content":
    "What is the weather in Dallas, Texas in Fahrenheit?"
}]
# Conversation in which the assistant has already issued one weather tool
# call and the tool has answered. The assistant's tool call and the tool
# message share the same tool call ID.
MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
    "role":
    "user",
    "content":
    "What is the weather in Dallas, Texas in Fahrenheit?"
}, {
    "role":
    "assistant",
    "tool_calls": [{
        "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
        "type": "function",
        "function": {
            "name":
            WEATHER_TOOL["function"]["name"],
            "arguments":
            '{"city": "Dallas", "state": "TX", '
            '"unit": "fahrenheit"}'
        }
    }]
}, {
    "role":
    "tool",
    "tool_call_id":
    "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
    "content":
    # Adjacent literals are joined with no separator; the trailing space
    # keeps "partly" and "cloudy" apart.
    "The weather in Dallas is 98 degrees fahrenheit, with partly "
    "cloudy skies and a low chance of rain."
}]
# Single user turn asking for the weather in two cities at once.
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{
    "role":
    "user",
    "content":
    "What is the weather in Dallas, Texas and Orlando, Florida in "
    "Fahrenheit?"
}]
# Conversation in which the assistant has already issued two weather tool
# calls (Dallas and Orlando) and both tools have answered. Each tool
# message carries the ID of the tool call it responds to.
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
    "role":
    "user",
    "content":
    "What is the weather in Dallas, Texas and Orlando, Florida in "
    "Fahrenheit?"
}, {
    "role":
    "assistant",
    "tool_calls": [{
        "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
        "type": "function",
        "function": {
            "name":
            WEATHER_TOOL["function"]["name"],
            "arguments":
            '{"city": "Dallas", "state": "TX", '
            '"unit": "fahrenheit"}'
        }
    }, {
        "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
        "type": "function",
        "function": {
            "name":
            WEATHER_TOOL["function"]["name"],
            "arguments":
            '{"city": "Orlando", "state": "Fl", '
            '"unit": "fahrenheit"}'
        }
    }]
}, {
    "role":
    "tool",
    "tool_call_id":
    "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
    "content":
    "The weather in Dallas TX is 98 degrees fahrenheit with mostly "
    "cloudy skies and a chance of rain in the evening."
}, {
    "role":
    "tool",
    "tool_call_id":
    "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
    "content":
    # Adjacent literals are joined with no separator; the trailing space
    # keeps "clear" and "skies." apart.
    "The weather in Orlando FL is 78 degrees fahrenheit with clear "
    "skies."
}]
import glob
import os
import runpy
import tempfile
import depyf
# disable custom dispatcher, let Dynamo take over
# all the control
os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"

# Fresh scratch directory passed to depyf.prepare_debug; the globs below
# look for generated .py files in it (presumably written by depyf while the
# example runs -- confirm against depyf's documentation).
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
    # Walk two directories up from this file's directory to find the
    # "examples" folder, then run the TPU offline-inference example.
    cur_dir = os.path.dirname(__file__)
    parent_dir = os.path.dirname(cur_dir)
    root_dir = os.path.dirname(parent_dir)
    example_file = os.path.join(root_dir, "examples",
                                "offline_inference_tpu.py")
    runpy.run_path(example_file)

# Sorted so the files are examined in a stable, name-based order.
compiled_code = sorted(
    glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))

# we should only trigger Dynamo compilation three times:
# one for the profiling phase without kv cache
# one for the prefill phase with symbolic shapes
# one for the decode phase with symbolic shapes
# and later calls should not trigger Dynamo compilation again.
# NOTE: it might still trigger XLA compilation.

# check we have three compiled code
# this is the assumption when we use the custom dispatcher
assert len(compiled_code) == 3

# check all the compilations are as expected
compiled_fn = sorted(
    glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))

# the first compilation is the profiling phase,
# it should not have any kv cache
with open(compiled_fn[0]) as f:
    content = f.read()
    assert "kv_caches" not in content

# the second compilation is the prefill phase,
# it should have kv cache and the flash_attention op
with open(compiled_fn[1]) as f:
    content = f.read()
    assert "kv_caches" in content and "torch.ops.xla.flash_attention" in content

# the third compilation is the decode phase,
# it should have kv cache and the paged_attention op
with open(compiled_fn[2]) as f:
    content = f.read()
    assert "kv_caches" in content and "torch.ops.xla.paged_attention" in content
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment