"vllm/vscode:/vscode.git/clone" did not exist on "3d4721f27f879fd23b866e4c2689ee77b923fb26"
Commit 99b471c2 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.1

parents 1925d2e9 468d761b
import pytest
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"enable_chunked_prefill": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError,
match="Speculative decoding and chunked prefill"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Speculative max model len > overridden max model len should raise.
"max_model_len": 128,
"speculative_max_model_len": 129,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1,
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
"speculative_max_model_len": 4096 + 1,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
"""Verify that speculative decoding validates speculative_max_model_len.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
"""Verify that speculative decoding with block manager v1 fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError,
match="Speculative decoding requires usage of the V2"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
"""The tests in this file verify end-to-end speculative decoding correctness.
This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.
For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).
NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
same input. vLLM largely guarantees this.
@cadedaniel has seen cases where the output probabilities of a draft/target
model change slightly with certain batch sizes or prompts, even with Torch
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
determinism in on-device batched operations, a bug in vLLM's spec decode
implementation, or the "hardware numerics" limitations. Either way, rejection
sampling ensures the output distribution matches the target model, but it breaks
greedy-equality tests for those batch sizes/prompts.
"""
from itertools import cycle
import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
{
# Verify the detokenizer assertions in the test work when spec
# decode is disabled.
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
batch_size: int):
"""Run generation with speculative decoding on a batch. Verify the engine
generates the correct number of tokens (via ignore_eos=True), and that the
detokenization matches HF transformers.
"""
output_len = 32
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
batch_tokens, batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
# Expect a generation for each prompt in the batch.
assert len(batch_token_ids) == len(prompts)
# Expect each generation to have expected number of tokens (note ignore_eos
# is True).
assert [len(token_ids)
for token_ids in batch_token_ids] == ([output_len] * batch_size)
# Expect detokenized string to match.
tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
expected_tokens = tok.decode(actual_token_ids)
print(f"{actual_token_ids=}")
assert actual_tokens.strip() == expected_tokens.strip()
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use long output len for the small model test.
1536,
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a tiny model with batch size of one.
Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [64])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a tiny model and large batch size.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("max_output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
baseline_llm_generator, test_llm_generator, batch_size: int,
max_output_len: int):
"""Verify greedy equality on a tiny model, with a large batch size, and when
sampling respects the EOS token.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len=False)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model": "meta-llama/Llama-2-7b-chat-hf",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
[
# Use decently long output len for a high quality test.
256,
])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a "real" model and batch size of 1. This is
separate from large BS tests to make identifying the source of bugs easier.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model": "meta-llama/Llama-2-7b-chat-hf",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality with a "real" model on a nontrivial batch size.
This is the closest test to a real production workload.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_with_preemption(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# As of this writing, vLLM only compiles with these 3 block sizes by
# default.
{
"block_size": 8,
},
{
"block_size": 16,
},
{
"block_size": 32,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_different_block_size(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality over different block sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when some (or all) sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with many different values of k.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos = force_output_len
sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)
for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
spec_tokens) in enumerate(
zip(baseline_batch_token_ids, baseline_batch_tokens,
spec_batch_token_ids, spec_batch_tokens)):
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids
...@@ -7,6 +7,7 @@ from .utils import create_seq_group_metadata_from_prompts, mock_worker ...@@ -7,6 +7,7 @@ from .utils import create_seq_group_metadata_from_prompts, mock_worker
@pytest.mark.parametrize('num_target_seq_ids', [100]) @pytest.mark.parametrize('num_target_seq_ids', [100])
@pytest.mark.skip_global_cleanup
def test_create_target_seq_id_iterator(num_target_seq_ids: int): def test_create_target_seq_id_iterator(num_target_seq_ids: int):
"""Verify all new sequence ids are greater than all input """Verify all new sequence ids are greater than all input
seq ids. seq ids.
...@@ -27,6 +28,7 @@ def test_create_target_seq_id_iterator(num_target_seq_ids: int): ...@@ -27,6 +28,7 @@ def test_create_target_seq_id_iterator(num_target_seq_ids: int):
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_get_token_ids_to_score(k: int): def test_get_token_ids_to_score(k: int):
"""Verify correct tokens are selected for scoring. """Verify correct tokens are selected for scoring.
""" """
...@@ -53,6 +55,7 @@ def test_get_token_ids_to_score(k: int): ...@@ -53,6 +55,7 @@ def test_get_token_ids_to_score(k: int):
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_create_single_target_seq_group_metadata(k: int): def test_create_single_target_seq_group_metadata(k: int):
"""Verify correct creation of a batch-expanded seq group metadata. """Verify correct creation of a batch-expanded seq group metadata.
""" """
......
...@@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): ...@@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
num_draft_tokens = 0 num_draft_tokens = 0
k = 5 k = 5
num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens( max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
num_draft_tokens, k) num_draft_tokens, k)
rej_sampler = MagicMock() rej_sampler = MagicMock()
...@@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): ...@@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
assert (metrics.draft_acceptance_rate == num_accepted_tokens / assert (metrics.draft_acceptance_rate == num_accepted_tokens /
num_draft_tokens) num_draft_tokens)
assert (metrics.system_efficiency == num_emitted_tokens / assert (metrics.system_efficiency == num_emitted_tokens /
num_possible_tokens) max_num_emitted_tokens)
else: else:
assert math.isnan(metrics.draft_acceptance_rate) assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency) assert math.isnan(metrics.system_efficiency)
...@@ -125,7 +125,7 @@ def test_same_output_for_single_step(): ...@@ -125,7 +125,7 @@ def test_same_output_for_single_step():
zero_kv_cache(worker.cache_engine) zero_kv_cache(worker.cache_engine)
set_random_seed(seed) set_random_seed(seed)
expected_output = worker.execute_model( expected_output = worker.execute_model(
**single_step_execute_model_data.to_dict(), ) **single_step_execute_model_data.to_dict(), )[0]
actual_token_ids = [ actual_token_ids = [
output.samples[0].output_token for output in actual_output output.samples[0].output_token for output in actual_output
...@@ -219,7 +219,7 @@ def test_same_output_for_multi_step(): ...@@ -219,7 +219,7 @@ def test_same_output_for_multi_step():
continuations=continuations, continuations=continuations,
final_seq_lens=final_seq_lens)) final_seq_lens=final_seq_lens))
single_step_output.append( single_step_output.extend(
worker.execute_model(**execute_model_data.to_dict(), )) worker.execute_model(**execute_model_data.to_dict(), ))
# Append output tokens to new sequence data. # Append output tokens to new sequence data.
...@@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations(): ...@@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations():
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([0, k]) assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k]) assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size]) assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)] assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
......
import random import random
from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
...@@ -6,6 +7,7 @@ import torch ...@@ -6,6 +7,7 @@ import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector, from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics) SpecDecodeWorkerMetrics)
...@@ -37,7 +39,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): ...@@ -37,7 +39,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
execute_model_data, _, _ = create_batch(batch_size, k) execute_model_data, _, _ = create_batch(batch_size, k)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
call_args_list = draft_worker.get_spec_proposals.call_args_list call_args_list = draft_worker.get_spec_proposals.call_args_list
assert len(call_args_list) == 1 assert len(call_args_list) == 1
...@@ -60,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -60,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the target model with correct """Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out. inputs. Everything else is mocked out.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker() target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
...@@ -102,7 +105,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -102,7 +105,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
target_worker.execute_model.side_effect = ValueError(exception_secret) target_worker.execute_model.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
seen_contexts = [] seen_contexts = []
...@@ -141,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -141,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
""" """
vocab_size = 32_000 vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) draft_worker = mock_worker(cls=MultiStepWorker,
target_worker = mock_worker(vocab_size=vocab_size) vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
...@@ -189,26 +195,26 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -189,26 +195,26 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
target_output = create_sampler_output_list(target_token_ids, target_output = create_sampler_output_list(target_token_ids,
target_token_probs) target_token_probs)
target_worker.execute_model.return_value = target_output[0] target_worker.execute_model.return_value = [target_output[0]]
exception_secret = 'artifical stop' exception_secret = 'artifical stop'
rejection_sampler.side_effect = ValueError(exception_secret) rejection_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
assert len(rejection_sampler.call_args_list) == 1 assert len(rejection_sampler.call_args_list) == 1
args, _ = rejection_sampler.call_args_list[0] _, kwargs = rejection_sampler.call_args_list[0]
(actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs, actual = SimpleNamespace(**kwargs)
actual_proposal_token_ids) = args
assert torch.equal(actual_bonus_token_ids, assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:]) target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal( assert torch.equal(
actual_proposal_scores, actual.target_probs,
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual_proposal_token_ids, proposal_token_ids) assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual_proposal_probs, proposal_probs) assert torch.equal(actual.draft_probs, proposal_probs)
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
...@@ -220,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -220,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int):
""" """
vocab_size = 32_000 vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) draft_worker = mock_worker(cls=MultiStepWorker,
target_worker = mock_worker(vocab_size=vocab_size) vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
...@@ -268,7 +276,7 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -268,7 +276,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
target_output = create_sampler_output_list(target_token_ids, target_output = create_sampler_output_list(target_token_ids,
target_token_probs) target_token_probs)
target_worker.execute_model.return_value = target_output[0] target_worker.execute_model.return_value = [target_output[0]]
rejection_sampler_output = torch.randint(low=0, rejection_sampler_output = torch.randint(low=0,
high=vocab_size, high=vocab_size,
...@@ -283,7 +291,7 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -283,7 +291,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
rejection_sampler.return_value = rejection_sampler_output rejection_sampler.return_value = rejection_sampler_output
output = worker.execute_model(**execute_model_data.to_dict(), output = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k) num_lookahead_slots=k)
expected_output = create_sampler_output_list( expected_output = create_sampler_output_list(
rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
...@@ -332,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -332,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
""" """
vocab_size = 32_000 vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) draft_worker = mock_worker(cls=MultiStepWorker,
target_worker = mock_worker(vocab_size=vocab_size) vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
...@@ -380,7 +390,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -380,7 +390,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
target_output = create_sampler_output_list(target_token_ids, target_output = create_sampler_output_list(target_token_ids,
target_token_probs) target_token_probs)
target_worker.execute_model.return_value = target_output[0] target_worker.execute_model.return_value = [target_output[0]]
rejection_sampler_output = torch.randint(low=0, rejection_sampler_output = torch.randint(low=0,
high=vocab_size, high=vocab_size,
...@@ -400,7 +410,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -400,7 +410,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
mock_rejsample_metrics) mock_rejsample_metrics)
output = worker.execute_model(**execute_model_data.to_dict(), output = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k) num_lookahead_slots=k)
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
call_args_list = ( call_args_list = (
...@@ -423,6 +433,8 @@ def test_k_equals_zero(k: int, batch_size: int): ...@@ -423,6 +433,8 @@ def test_k_equals_zero(k: int, batch_size: int):
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
...@@ -435,7 +447,7 @@ def test_k_equals_zero(k: int, batch_size: int): ...@@ -435,7 +447,7 @@ def test_k_equals_zero(k: int, batch_size: int):
batch_size, k, prev_output_token_len=0) batch_size, k, prev_output_token_len=0)
out = worker.execute_model(**execute_model_data.to_dict(), out = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k) num_lookahead_slots=k)
assert len(out) == 1, f"expected only one token output when {k=}" assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None" assert out[0].probs is None, "expect gpu tensor references to be None"
...@@ -443,7 +455,7 @@ def test_k_equals_zero(k: int, batch_size: int): ...@@ -443,7 +455,7 @@ def test_k_equals_zero(k: int, batch_size: int):
0].sampled_tokens is None, "expect gpu tensor references to be None" 0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with( draft_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict(), return_python_output=False) **execute_model_data.to_dict())
target_worker.execute_model.assert_called_once_with( target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict()) **execute_model_data.to_dict())
...@@ -462,6 +474,8 @@ def test_empty_input_batch(k: int, batch_size: int): ...@@ -462,6 +474,8 @@ def test_empty_input_batch(k: int, batch_size: int):
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
...@@ -474,7 +488,7 @@ def test_empty_input_batch(k: int, batch_size: int): ...@@ -474,7 +488,7 @@ def test_empty_input_batch(k: int, batch_size: int):
batch_size, k, prev_output_token_len=0) batch_size, k, prev_output_token_len=0)
out = worker.execute_model(**execute_model_data.to_dict(), out = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k) num_lookahead_slots=k)
assert len(out) == 1, f"expected only one token output when {k=}" assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None" assert out[0].probs is None, "expect gpu tensor references to be None"
...@@ -482,18 +496,18 @@ def test_empty_input_batch(k: int, batch_size: int): ...@@ -482,18 +496,18 @@ def test_empty_input_batch(k: int, batch_size: int):
0].sampled_tokens is None, "expect gpu tensor references to be None" 0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with( draft_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict(), return_python_output=False) **execute_model_data.to_dict())
target_worker.execute_model.assert_called_once_with( target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict()) **execute_model_data.to_dict())
@torch.inference_mode() @pytest.mark.skip_global_cleanup
def test_init_device(): def test_init_device():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization. well as other GPU initialization.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker() target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
...@@ -512,8 +526,8 @@ def test_init_device(): ...@@ -512,8 +526,8 @@ def test_init_device():
@torch.inference_mode() @torch.inference_mode()
def test_init_cache_engine(): def test_initialize_cache():
"""Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers. workers.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
...@@ -525,23 +539,22 @@ def test_init_cache_engine(): ...@@ -525,23 +539,22 @@ def test_init_cache_engine():
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector) metrics_collector)
cache_config = MagicMock() kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs)
worker.init_cache_engine(cache_config)
draft_worker.init_cache_engine.assert_called_once_with(cache_config) draft_worker.initialize_cache.assert_called_once_with(**kwargs)
target_worker.init_cache_engine.assert_called_once_with(cache_config) target_worker.initialize_cache.assert_called_once_with(**kwargs)
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024]) @pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500]) @pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode() @pytest.mark.skip_global_cleanup
def test_profile_num_available_blocks(available_gpu_blocks: int, def test_determine_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int, available_cpu_blocks: int,
target_cache_block_size_bytes: int, target_cache_block_size_bytes: int,
draft_kv_size_bytes: int): draft_kv_size_bytes: int):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks. """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker. split the blocks between proposer and scorer worker.
...@@ -552,7 +565,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, ...@@ -552,7 +565,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.profile_num_available_blocks.return_value = ( target_worker.determine_num_available_blocks.return_value = (
available_gpu_blocks, available_cpu_blocks) available_gpu_blocks, available_cpu_blocks)
target_worker.get_cache_block_size_bytes.return_value = ( target_worker.get_cache_block_size_bytes.return_value = (
target_cache_block_size_bytes) target_cache_block_size_bytes)
...@@ -561,17 +574,9 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, ...@@ -561,17 +574,9 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector) metrics_collector)
# These values do not directly impact the adjusted block size calculation, num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
# so they can be fixed.
gpu_memory_utilization = 0.9
cpu_swap_space = 100
block_size = 16
num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto")
target_worker.profile_num_available_blocks.assert_called_once_with( target_worker.determine_num_available_blocks.assert_called_once()
block_size, gpu_memory_utilization, cpu_swap_space, "auto")
assert num_cpu_blocks == available_cpu_blocks assert num_cpu_blocks == available_cpu_blocks
assert num_gpu_blocks == split_num_cache_blocks_evenly( assert num_gpu_blocks == split_num_cache_blocks_evenly(
...@@ -584,7 +589,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, ...@@ -584,7 +589,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
@pytest.mark.parametrize('target_cache_block_size_bytes', @pytest.mark.parametrize('target_cache_block_size_bytes',
[2 * 2 * 4096, 2 * 2 * 8192]) [2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode() @pytest.mark.skip_global_cleanup
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int, def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
target_cache_block_size_bytes: int, target_cache_block_size_bytes: int,
draft_kv_size_bytes: int): draft_kv_size_bytes: int):
......
...@@ -63,11 +63,14 @@ def create_execute_model_data( ...@@ -63,11 +63,14 @@ def create_execute_model_data(
def mock_worker(cls=None, def mock_worker(cls=None,
vocab_size: int = 30_000, vocab_size: int = 30_000,
max_model_len: int = 2048, max_model_len: int = 2048,
rank: int = 0) -> MagicMock: rank: int = 0,
use_spec: bool = True) -> MagicMock:
if cls is None: if cls is None:
cls = Worker cls = Worker
worker = MagicMock(spec=cls) spec = cls if use_spec else None
worker = MagicMock(spec=spec)
worker.vocab_size = vocab_size worker.vocab_size = vocab_size
worker.max_model_len = max_model_len worker.max_model_len = max_model_len
worker.rank = rank worker.rank = rank
...@@ -107,18 +110,18 @@ def create_worker(cls: type, ...@@ -107,18 +110,18 @@ def create_worker(cls: type,
block_size=block_size, block_size=block_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
) )
engine_config = engine_args.create_engine_config()
(model_config, cache_config, parallel_config, scheduler_config,
device_config, _, _) = engine_args.create_engine_configs()
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
worker = cls( worker = cls(
model_config=model_config, model_config=engine_config.model_config,
parallel_config=parallel_config, parallel_config=engine_config.parallel_config,
scheduler_config=scheduler_config, scheduler_config=engine_config.scheduler_config,
device_config=device_config, device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
...@@ -128,10 +131,11 @@ def create_worker(cls: type, ...@@ -128,10 +131,11 @@ def create_worker(cls: type,
worker.init_device() worker.init_device()
worker.load_model() worker.load_model()
cache_config.num_gpu_blocks = num_gpu_blocks engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
cache_config.num_cpu_blocks = 0 engine_config.cache_config.num_cpu_blocks = 0
worker.init_cache_engine(cache_config) worker.initialize_cache(
worker.warm_up_model() num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
return worker return worker
...@@ -211,7 +215,7 @@ def create_sampler_output_list( ...@@ -211,7 +215,7 @@ def create_sampler_output_list(
SequenceOutput( SequenceOutput(
output_token=token_id, output_token=token_id,
parent_seq_id=seq_ids[seq_index], parent_seq_id=seq_ids[seq_index],
logprobs={token_id: 0}, logprobs={token_id: Logprob(0)},
) )
], ],
prompt_logprobs=None, prompt_logprobs=None,
......
import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type
import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig
from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
from vllm.model_executor.models import ModelRegistry
# yapf conflicts with isort for this docstring
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and
deserialize vLLM models. These models can be loaded using tensorizer directly
to the GPU extremely quickly. Tensor encryption and decryption is also
supported, although libsodium must be installed to use it. Install
vllm with tensorizer support using `pip install vllm[tensorizer]`.
To serialize a model, you can run something like this:
python tensorize_vllm_model.py \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
serialize \
--serialized-directory s3://my-bucket/ \
--suffix vllm
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used.
You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.
To deserialize a model, you can run something like this:
python tensorize_vllm_model.py \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
deserialize \
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
Which downloads the model tensors from your S3 bucket and deserializes them.
To provide S3 credentials, you can provide `--s3-access-key-id` and
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
the OpenAI entrypoint, as arguments for LLM(), or as environment variables
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
For more information on the available arguments, run
`python tensorize_vllm_model.py --help`.
"""
def parse_args():
parser = argparse.ArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
"can be loaded using tensorizer directly to the GPU "
"extremely quickly. Tensor encryption and decryption is "
"also supported, although libsodium must be installed to "
"use it.")
parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser))
subparsers = parser.add_subparsers(dest='command')
serialize_parser = subparsers.add_parser(
'serialize', help="Serialize a model to `--serialized-directory`")
serialize_parser.add_argument(
"--suffix",
type=str,
required=False,
help=(
"The suffix to append to the serialized model directory, which is "
"used to construct the location of the serialized model tensors, "
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
"`--suffix` is `v1`, the serialized model tensors will be "
"saved to "
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
"If none is provided, a random UUID will be used."))
serialize_parser.add_argument(
"--serialized-directory",
type=str,
required=True)
serialize_parser.add_argument(
"--keyfile",
type=str,
required=False,
help=("Encrypt the model weights with a randomly-generated binary key,"
" and save the key at this path"))
deserialize_parser = subparsers.add_parser(
'deserialize',
help=("Deserialize a model from `--path-to-tensors`"
" to verify it can be loaded and used."))
deserialize_parser.add_argument(
"--path-to-tensors",
type=str,
required=True,
help="The local path or S3 URI to the model tensors to deserialize. ")
deserialize_parser.add_argument(
"--keyfile",
type=str,
required=False,
help=("Path to a binary key to use to decrypt the model weights,"
" if the model was serialized with encryption"))
return parser.parse_args()
def make_model_contiguous(model):
# Ensure tensors are saved in memory contiguously
for param in model.parameters():
param.data = param.data.contiguous()
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def serialize():
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
engine = LLMEngine.from_engine_args(engine_args)
model = (engine.model_executor.driver_worker.
model_runner.model)
encryption_params = EncryptionParams.random() if keyfile else None
if keyfile:
with _write_stream(keyfile) as stream:
stream.write(encryption_params.key)
with _write_stream(model_path) as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model)
serializer.close()
print("Serialization complete. Model tensors saved to", model_path)
if keyfile:
print("Key saved to", keyfile)
def deserialize():
config = AutoConfig.from_pretrained(model_ref)
with no_init_or_tensor():
model_class = _get_vllm_model_architecture(config)
model = model_class(config)
before_mem = get_mem_usage()
start = time.time()
if keyfile:
with _read_stream(keyfile) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
tensorizer_args.deserializer_params['encryption'] = \
decryption_params
with (_read_stream(model_path)) as stream, TensorDeserializer(
stream, **tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(model)
end = time.time()
# Brag about how fast we are.
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
print(
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
)
print(f"Memory usage before: {before_mem}")
print(f"Memory usage after: {after_mem}")
return model
args = parse_args()
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
or None)
s3_secret_access_key = (args.s3_secret_access_key
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
_read_stream, _write_stream = (partial(
stream_io.open_stream,
mode=mode,
s3_access_key_id=s3_access_key_id,
s3_secret_access_key=s3_secret_access_key,
s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))
model_ref = args.model
model_name = model_ref.split("/")[1]
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"
torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()
keyfile = args.keyfile if args.keyfile else None
if args.command == "serialize":
input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
model_path = f"{base_path}/model.tensors"
serialize()
elif args.command == "deserialize":
tensorizer_args = TensorizerArgs.from_cli_args(args)
model_path = args.path_to_tensors
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")
import gc
import json
import os
import subprocess
from unittest.mock import MagicMock, patch
import openai
import pytest
import ray
import torch
from tests.entrypoints.test_openai_server import ServerRunner
from vllm import SamplingParams
from vllm.model_executor.model_loader.tensorizer import (
EncryptionParams, TensorizerConfig, TensorSerializer,
is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
model_ref = "facebook/opt-125m"
tensorize_model_for_testing_script = os.path.join(
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
def is_curl_installed():
try:
subprocess.check_call(['curl', '--version'])
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
@pytest.fixture(autouse=True)
def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True)
return config
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_linear_method = MagicMock()
mock_agent_instance = mock_agent.return_value
mock_agent_instance.deserialize.return_value = MagicMock()
result = load_with_tensorizer(tensorizer_config,
linear_method=mock_linear_method)
mock_agent.assert_called_once_with(tensorizer_config,
linear_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value
def test_is_vllm_model_with_vllm_in_uri(tensorizer_config):
tensorizer_config.vllm_tensorized = True
result = is_vllm_serialized_tensorizer(tensorizer_config)
assert result is True
def test_is_vllm_model_without_vllm_in_uri(tensorizer_config):
tensorizer_config.vllm_tensorized = False
result = is_vllm_serialized_tensorizer(tensorizer_config)
assert result is False
def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
vllm_model = vllm_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
outputs = vllm_model.generate(prompts, sampling_params)
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(model)
del vllm_model, model
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(
model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path,
num_readers=1,
vllm_tensorized=True),
)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
# Assumes SamplingParams being seeded ensures the outputs are deterministic
assert outputs == deserialized_outputs
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
loaded_hf_model = vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path,
num_readers=1,
vllm_tensorized=False,
s3_endpoint="object.ord1.coreweave.com",
))
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
assert deserialized_outputs
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
vllm_runner, tmp_path):
vllm_model = vllm_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
outputs = vllm_model.generate(prompts, sampling_params)
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
encryption_params = EncryptionParams.random()
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model)
with open_stream(key_path, "wb+") as stream:
stream.write(encryption_params.key)
del vllm_model, model
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=key_path,
num_readers=1,
vllm_tensorized=True))
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
# Assumes SamplingParams being seeded ensures the outputs are deterministic
assert outputs == deserialized_outputs
def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
tmp_path):
hf_model = hf_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
max_tokens = 50
outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(hf_model.model)
del hf_model
gc.collect()
torch.cuda.empty_cache()
loaded_hf_model = vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path,
num_readers=1,
vllm_tensorized=False))
deserialized_outputs = loaded_hf_model.generate_greedy(
prompts, max_tokens=max_tokens)
assert outputs == deserialized_outputs
def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
from huggingface_hub import snapshot_download
from examples.multilora_inference import (create_test_prompts,
process_requests)
model_ref = "meta-llama/Llama-2-7b-hf"
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
test_prompts = create_test_prompts(lora_path)
# Serialize model before deserializing and binding LoRA adapters
vllm_model = vllm_runner(model_ref, )
model_path = tmp_path / (model_ref + ".tensors")
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(model)
del vllm_model, model
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(
model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path,
num_readers=1,
vllm_tensorized=True,
),
enable_lora=True,
max_loras=1,
max_lora_rank=8,
max_cpu_loras=2,
max_num_seqs=50,
max_model_len=1000,
)
process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
assert loaded_vllm_model
def test_load_without_tensorizer_load_format(vllm_runner):
with pytest.raises(ValueError):
vllm_runner(model_ref,
model_loader_extra_config=TensorizerConfig(
tensorizer_uri="test", vllm_tensorized=False))
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorize_vllm_model(tmp_path):
# Test serialize command
serialize_args = [
"python3", tensorize_model_for_testing_script, "--model", model_ref,
"--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
"--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
print(result.stdout) # Print the output of the serialize command
assert result.returncode == 0, (f"Serialize command failed with output:"
f"\n{result.stdout}\n{result.stderr}")
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
# Test deserialize command
deserialize_args = [
"python3", tensorize_model_for_testing_script, "--model", model_ref,
"--dtype", "float16", "deserialize", "--path-to-tensors",
path_to_tensors
]
result = subprocess.run(deserialize_args, capture_output=True, text=True)
assert result.returncode == 0, (f"Deserialize command failed with output:"
f"\n{result.stdout}\n{result.stderr}")
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(tmp_path):
## Serialize model
serialize_args = [
"python3", tensorize_model_for_testing_script, "--model", model_ref,
"--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
"--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
print(result.stdout) # Print the output of the serialize command
assert result.returncode == 0, (f"Serialize command failed with output:"
f"\n{result.stdout}\n{result.stderr}")
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
model_loader_extra_config = {
"tensorizer_uri": path_to_tensors,
"vllm_tensorized": True
}
## Start OpenAI API server
openai_args = [
"--model", model_ref, "--dtype", "float16", "--load-format",
"tensorizer", "--model-loader-extra-config",
json.dumps(model_loader_extra_config), "--port", "8000"
]
server = ServerRunner.remote(openai_args)
assert ray.get(server.ready.remote())
print("Server ready.")
client = openai.OpenAI(
base_url="http://localhost:8000/v1",
api_key="token-abc123",
)
completion = client.completions.create(model=model_ref,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
def test_raise_value_error_on_invalid_load_format(vllm_runner):
with pytest.raises(ValueError):
vllm_runner(model_ref,
load_format="safetensors",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri="test", vllm_tensorized=False))
def test_tensorizer_with_tp(vllm_runner):
with pytest.raises(ValueError):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
vllm_runner(
model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path,
num_readers=1,
vllm_tensorized=False,
s3_endpoint="object.ord1.coreweave.com",
),
tensor_parallel_size=2,
)
...@@ -11,8 +11,6 @@ def test_get_sliding_window(): ...@@ -11,8 +11,6 @@ def test_get_sliding_window():
"Qwen/Qwen1.5-7B", "Qwen/Qwen1.5-7B",
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
...@@ -30,8 +28,6 @@ def test_get_sliding_window(): ...@@ -30,8 +28,6 @@ def test_get_sliding_window():
"mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-v0.1",
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
......
import os
import sys
import tempfile
from vllm.logger import enable_trace_function_call
def f1(x):
return f2(x)
def f2(x):
return x
def test_trace_function_call():
fd, path = tempfile.mkstemp()
cur_dir = os.path.dirname(__file__)
enable_trace_function_call(path, cur_dir)
f1(1)
with open(path, 'r') as f:
content = f.read()
assert "f1" in content
assert "f2" in content
sys.settrace(None)
os.remove(path)
...@@ -37,7 +37,12 @@ def _prepare_test( ...@@ -37,7 +37,12 @@ def _prepare_test(
1e-2, 1e-2,
dtype=input_tensor.dtype) dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
model_runner = ModelRunner(None, None, None, None, None) model_runner = ModelRunner(model_config=None,
parallel_config=None,
scheduler_config=None,
device_config=None,
load_config=None,
lora_config=None)
return input_tensor, fake_logits, logits_processor, model_runner return input_tensor, fake_logits, logits_processor, model_runner
......
import time
from typing import Optional
import pytest import pytest
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, from vllm import SamplingParams
SequenceOutput) from vllm.lora.request import LoRARequest
from vllm.sequence import (SamplerOutput, Sequence, SequenceData,
SequenceGroup, SequenceGroupOutput, SequenceOutput)
def create_dummy_prompt(
request_id: str,
prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
) -> SequenceGroup:
if not block_size:
block_size = prompt_length
# Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
seq_group = SequenceGroup(
request_id, [prompt],
SamplingParams(use_beam_search=use_beam_search, best_of=best_of),
time.time(), lora_request)
return seq_group
@pytest.fixture @pytest.fixture
...@@ -67,6 +96,29 @@ def test_sequence_data_prefill(): ...@@ -67,6 +96,29 @@ def test_sequence_data_prefill():
# append tokens and reset, simulating recompute # append tokens and reset, simulating recompute
seq_data.append_token_id(1, logprob=0.0) seq_data.append_token_id(1, logprob=0.0)
seq_data.reset_num_computed_tokens() seq_data.reset_state_for_recompute()
assert seq_data.get_num_uncomputed_tokens() == 5 assert seq_data.get_num_uncomputed_tokens() == 5
assert seq_data.get_num_computed_tokens() == 0 assert seq_data.get_num_computed_tokens() == 0
def test_sequence_group_stage():
seq_group = create_dummy_prompt("1", 12)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(6)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(5)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(1)
assert seq_group.is_prefill() is False
seqs = seq_group.get_seqs()
assert len(seqs) == 1
seqs[0].data.append_token_id(1, logprob=0.0)
for seq in seq_group.get_seqs():
seq.reset_state_for_recompute()
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(5)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(7)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(1)
assert seq_group.is_prefill() is False
...@@ -4,8 +4,8 @@ import pytest ...@@ -4,8 +4,8 @@ import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import (Detokenizer,
from vllm.transformers_utils.tokenizer import detokenize_incrementally detokenize_incrementally)
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
TRUTH = [ TRUTH = [
......
import pytest import pytest
import torch import torch
from vllm.config import ModelConfig from vllm.config import ModelConfig, SchedulerConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
@pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size): def test_prepare_prompt(batch_size):
model_runner = ModelRunner(None, None, None, None, None) scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=False)
model_runner = ModelRunner(model_config=None,
parallel_config=None,
scheduler_config=scheduler_config,
device_config=None,
load_config=None,
lora_config=None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] prompt_lens = []
...@@ -36,8 +45,10 @@ def test_prepare_prompt(batch_size): ...@@ -36,8 +45,10 @@ def test_prepare_prompt(batch_size):
prompt_len - 1) prompt_len - 1)
selected_token_start_idx += prompt_len selected_token_start_idx += prompt_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) _, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens assert return_prompt_lens == prompt_lens
assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
...@@ -45,8 +56,6 @@ def test_prepare_prompt(batch_size): ...@@ -45,8 +56,6 @@ def test_prepare_prompt(batch_size):
assert torch.allclose(attn_metadata.prompt_lens_tensor, assert torch.allclose(attn_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device)) torch.tensor(prompt_lens, device=device))
assert attn_metadata.prompt_lens == prompt_lens assert attn_metadata.prompt_lens == prompt_lens
assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
assert attn_metadata.num_generation_tokens == 0
assert attn_metadata.max_prompt_len == max(prompt_lens) assert attn_metadata.max_prompt_len == max(prompt_lens)
# Test subquery start locs. # Test subquery start locs.
...@@ -83,23 +92,22 @@ def test_prepare_prompt(batch_size): ...@@ -83,23 +92,22 @@ def test_prepare_prompt(batch_size):
assert torch.allclose(attn_metadata.block_tables, expected) assert torch.allclose(attn_metadata.block_tables, expected)
# Cuda graph should not be used for prerill. # Cuda graph should not be used for prerill.
assert attn_metadata.use_cuda_graph is False assert attn_metadata.use_cuda_graph is False
assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (sum(prompt_lens), ) assert len(input_tokens) == sum(prompt_lens)
assert input_positions.shape == (sum(prompt_lens), ) assert len(input_positions) == sum(prompt_lens)
torch.testing.assert_close(input_tokens, input_positions) torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens, prompt_lens,
subquery_lens=prompt_lens) subquery_lens=prompt_lens)
assert input_tokens.shape == (sum(prompt_lens), ) assert len(input_tokens) == sum(prompt_lens)
assert input_positions.shape == (sum(prompt_lens), ) assert len(input_positions) == sum(prompt_lens)
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices, expected = torch.tensor(expected_selected_token_indices,
device=actual.device, device=actual.device,
dtype=actual.dtype) dtype=actual.dtype)
torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual, expected)
torch.testing.assert_close(input_tokens, input_positions) assert input_tokens == input_positions
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices, expected = torch.tensor(expected_selected_token_indices,
...@@ -115,14 +123,21 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -115,14 +123,21 @@ def test_prepare_decode_cuda_graph(batch_size):
"facebook/opt-125m", "facebook/opt-125m",
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
enforce_eager=False, enforce_eager=False,
) )
model_runner = ModelRunner(model_config, None, None, None, None) scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=False)
model_runner = ModelRunner(model_config=model_config,
parallel_config=None,
scheduler_config=scheduler_config,
device_config=None,
load_config=None,
lora_config=None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] prompt_lens = []
...@@ -143,16 +158,15 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -143,16 +158,15 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1 assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
input_tokens, input_positions, attn_metadata, _, _, _ = ( input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list)) model_runner._prepare_decode(seq_group_metadata_list))
assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
assert attn_metadata.is_prompt is False assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None assert attn_metadata.prompt_lens is None
assert attn_metadata.num_prompt_tokens == 0
assert attn_metadata.num_generation_tokens == expected_bs
assert attn_metadata.max_prompt_len is None assert attn_metadata.max_prompt_len is None
assert attn_metadata.subquery_start_loc is None assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None assert attn_metadata.seq_start_loc is None
...@@ -170,11 +184,10 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -170,11 +184,10 @@ def test_prepare_decode_cuda_graph(batch_size):
model_runner.get_max_block_per_batch()) model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prerill. # Cuda graph should not be used for prerill.
assert attn_metadata.use_cuda_graph is True assert attn_metadata.use_cuda_graph is True
assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (expected_bs, ) assert len(input_tokens) == expected_bs
assert input_positions.shape == (expected_bs, ) assert len(input_positions) == expected_bs
torch.testing.assert_close(input_tokens, input_positions) assert input_tokens == input_positions
# Verify Sampling # Verify Sampling
expected_selected_token_indices = [] expected_selected_token_indices = []
...@@ -190,3 +203,150 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -190,3 +203,150 @@ def test_prepare_decode_cuda_graph(batch_size):
device=actual.device, device=actual.device,
dtype=actual.dtype) dtype=actual.dtype)
torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual, expected)
def test_empty_seq_group():
"""Verify prepare prompt and decode returns empty output."""
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
model_runner = ModelRunner(model_config=model_config,
parallel_config=None,
scheduler_config=None,
device_config=None,
load_config=None,
lora_config=None)
model_runner.set_block_size(16)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_prompt_lens) == 0
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
def get_world_size(group=None):
return 1
def mock_get_process_group_ranks(group=None):
return [0]
monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
mock_get_process_group_ranks)
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
enforce_eager=enforce_eager,
)
scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=True)
model_runner = ModelRunner(model_config=model_config,
parallel_config=None,
scheduler_config=scheduler_config,
device_config=None,
load_config=None,
lora_config=None,
is_driver_worker=True)
model_runner.set_block_size(16)
# Add prefill requests.
prompt_lens = []
seq_group_metadata_list = []
prefill_metadata_list = []
decode_metadata_list = []
block_tables = {0: [1]}
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
for i in range(prefill_batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = SequenceData(list(range(prompt_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
)
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
seq_group_metadata_list.append(seq_group_metadata)
prefill_metadata_list.append(seq_group_metadata)
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(prompt_len))
seq_data = SequenceData(prompt_toks)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
)
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
decode_metadata_list.append(seq_group_metadata)
(input_tokens, input_positions, attn_metadata, _, _, _,
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.kv_cache_dtype == "auto"
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_prefill_tokens == sum(prompt_lens)
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
prefill_meta = model_runner._prepare_prompt(
prefill_metadata_list).attn_metadata
decode_meta = model_runner._prepare_decode(
decode_metadata_list).attn_metadata
for attr_expected, attr_actual in zip(vars(prefill_meta),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(vars(decode_meta),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]
...@@ -10,19 +10,20 @@ def test_swap() -> None: ...@@ -10,19 +10,20 @@ def test_swap() -> None:
engine_args = EngineArgs(model="facebook/opt-125m", engine_args = EngineArgs(model="facebook/opt-125m",
dtype="half", dtype="half",
load_format="dummy") load_format="dummy")
(model_config, cache_config, parallel_config, scheduler_config, engine_config = engine_args.create_engine_config()
device_config, _, _) = engine_args.create_engine_configs() engine_config.cache_config.num_gpu_blocks = 1000
cache_config.num_gpu_blocks = 100 engine_config.cache_config.num_cpu_blocks = 1000
cache_config.num_cpu_blocks = 100
# Create the worker. # Create the worker.
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
worker = Worker( worker = Worker(
model_config=model_config, model_config=engine_config.model_config,
parallel_config=parallel_config, parallel_config=engine_config.parallel_config,
scheduler_config=scheduler_config, scheduler_config=engine_config.scheduler_config,
device_config=device_config, device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
...@@ -32,8 +33,9 @@ def test_swap() -> None: ...@@ -32,8 +33,9 @@ def test_swap() -> None:
# Initialize the worker. # Initialize the worker.
worker.init_device() worker.init_device()
worker.load_model() worker.load_model()
worker.init_cache_engine(cache_config) worker.initialize_cache(
worker.warm_up_model() num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
# Randomly initialize the cache. # Randomly initialize the cache.
gpu_cache = worker.cache_engine.gpu_cache gpu_cache = worker.cache_engine.gpu_cache
......
...@@ -5,14 +5,16 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine ...@@ -5,14 +5,16 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_ray_cluster from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.version import __dcu_version__ from vllm.version import __dcu_version__
__version__ = "0.4.0" __version__ = "0.4.1"
__all__ = [ __all__ = [
"LLM", "LLM",
"ModelRegistry",
"SamplingParams", "SamplingParams",
"RequestOutput", "RequestOutput",
"CompletionOutput", "CompletionOutput",
......
from typing import Dict, Optional, Tuple
import torch
try:
from vllm._C import cache_ops as vllm_cache_ops
from vllm._C import ops as vllm_ops
except ImportError:
pass
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.silu_and_mul(out, x)
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_and_mul(out, x)
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_tanh_and_mul(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_fast(out, x)
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_new(out, x)
# page attention ops
def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables,
context_lens, block_size, max_context_len,
alibi_slopes, kv_cache_dtype, kv_scale)
def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes, kv_cache_dtype,
kv_scale)
# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
is_neox)
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
vllm_ops.rms_norm(out, input, weight, epsilon)
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor:
return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
vllm_ops.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
lookup_table: torch.Tensor) -> None:
vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)
# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
size_n, size_k)
# fp8
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
vllm_ops.scaled_fp8_quant(output, input, scale)
return output, scale
# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor) -> None:
vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids,
num_tokens_post_pad)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, kv_scale)
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
block_mapping: torch.Tensor) -> None:
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: Dict[int, int]) -> None:
vllm_cache_ops.swap_blocks(src, dst, block_mapping)
def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
vllm_cache_ops.convert_fp8(output, input)
#TODO: cuda_utils, custom_ar
from vllm.attention.backends.abstract import (AttentionBackend, from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
...@@ -8,4 +9,5 @@ __all__ = [ ...@@ -8,4 +9,5 @@ __all__ = [
"AttentionMetadata", "AttentionMetadata",
"Attention", "Attention",
"get_attn_backend", "get_attn_backend",
"AttentionMetadataPerStage",
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment