Commit 7a985548 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.0' into v0.9.0-ori

parents 45d3785c dc1440cf
......@@ -152,7 +152,7 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
......
# SPDX-License-Identifier: Apache-2.0
"""This docstring details important information on the testing methodology.
This test verifies that memory usage remains constant (or never grows) when
we enable / disable speculation via --speculative-disable-by-batch-size.
There are a lot of things we try to keep track of between batches of requests,
and if certain tensors are not freed from memory, this can result in CUDA OOMs.
This is particularly relevant for production situations where speculation might
be enabled during off hours, but disabled once traffic peaks during the workday.
Since traffic will stay high for a long period of time, verifying we do not
increase our memory usage over time is essential to prevent possible CUDA ooms.
"""
import torch
import vllm
from tests.core.utils import create_dummy_prompt
from vllm.sequence import SequenceGroup
ITERATIONS = 100
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
BATCH_SIZE = 5
SPEC_DISABLE_BATCH_SIZE = 2
def add_seq_group_to_engine(engine: vllm.LLMEngine, seq_group: SequenceGroup):
    """Hand ``seq_group`` straight to the engine's first scheduler."""
    engine.scheduler[0].add_seq_group(seq_group)
"""
Since we are using a batch size greater than the disabled batch size,
we can ensure we go through the _no_spec codepath for most of our engine steps.
"""
def test_memory_usage_no_spec():
    """Check that allocated CUDA memory does not change as requests churn.

    One request is added per iteration. Once more than ``BATCH_SIZE``
    sequences are in flight, the engine is stepped until at least one of
    them finishes, and ``torch.cuda.memory_allocated()`` is compared with
    the first value recorded at such a point.
    """
    llm = vllm.LLM(model=MAIN_MODEL,
                   speculative_config={
                       "model": SPEC_MODEL,
                       "num_speculative_tokens": 3,
                       "disable_by_batch_size": SPEC_DISABLE_BATCH_SIZE,
                   })
    engine = llm.llm_engine
    in_flight = set()
    baseline_memory = None

    for request_idx in range(ITERATIONS):
        new_seq, seq_group = create_dummy_prompt(
            request_id=str(request_idx),
            prompt_length=10,
            min_tokens=10,
            max_tokens=10,
        )
        add_seq_group_to_engine(engine, seq_group)
        in_flight.add(new_seq)

        engine.step()
        # Forget every sequence that completed during this step.
        in_flight -= {s for s in in_flight if s.is_finished()}

        # Keep filling until the batch exceeds the target size.
        if len(in_flight) <= BATCH_SIZE:
            continue

        # Step until at least one in-flight request has completed ...
        while not any(s.is_finished() for s in in_flight):
            engine.step()
        # ... and drop whatever finished.
        in_flight -= {s for s in in_flight if s.is_finished()}

        # Some requests of the batch have now been fully processed, so the
        # allocation must match the first value recorded at this point.
        current_memory = torch.cuda.memory_allocated()
        if baseline_memory is None:
            baseline_memory = current_memory
        else:
            assert baseline_memory == current_memory
# SPDX-License-Identifier: Apache-2.0
import functools
import gc
from typing import Callable, TypeVar
import pytest
import torch
from typing_extensions import ParamSpec
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
......@@ -25,32 +18,6 @@ def cleanup():
cleanup_dist_env_and_memory(shutdown_ray=True)
_P = ParamSpec("_P")
_R = TypeVar("_R")
def retry_until_skip(n: int):
    """Build a decorator that retries a test on ``AssertionError``.

    The wrapped callable is attempted up to ``n`` times. After every failed
    attempt the garbage collector runs and the CUDA cache is emptied; when
    the final attempt fails too, the test is skipped instead of failing.
    """

    def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:

        @functools.wraps(func)
        def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            final_attempt = n - 1
            for attempt in range(n):
                try:
                    outcome = func(*args, **kwargs)
                except AssertionError:
                    # Free as much memory as possible before the next try.
                    gc.collect()
                    torch.cuda.empty_cache()
                    if attempt == final_attempt:
                        pytest.skip(f"Skipping test after {n} attempts.")
                else:
                    return outcome

            raise AssertionError("Code should not be reached")

        return wrapper_retry

    return decorator_retry
@pytest.fixture(autouse=True)
def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm")
......
......@@ -28,7 +28,6 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
from vllm.utils import PlaceholderModule, import_from_path
from ..utils import VLLM_PATH, RemoteOpenAIServer
from .conftest import retry_until_skip
try:
from tensorizer import EncryptionParams
......@@ -325,7 +324,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
assert outputs == deserialized_outputs
@retry_until_skip(3)
@pytest.mark.flaky(reruns=3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
gc.collect()
torch.cuda.empty_cache()
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import MISSING, Field, asdict, dataclass, field
from typing import Literal, Union
import pytest
from vllm.config import ModelConfig, PoolerConfig, get_field
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
config, get_field)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
# Plain class that is not a dataclass; `test_config` pairs it with the
# "must be a dataclass" error.
class TestConfig1:
    pass
# Dataclass whose field `a` has a docstring but no default value;
# `test_config` pairs it with the "must have a default" error.
@dataclass
class TestConfig2:
    a: int
    """docstring"""
# Dataclass whose field `a` has a default but no docstring after it;
# `test_config` pairs it with the "must have a docstring" error.
@dataclass
class TestConfig3:
    a: int = 1
# Dataclass whose field `a` is annotated as a Union of two separate Literal
# types; `test_config` pairs it with the "must use a single Literal" error.
@dataclass
class TestConfig4:
    a: Union[Literal[1], Literal[2]] = 1
    """docstring"""
@pytest.mark.parametrize(
    ("test_config", "expected_error"),
    [
        (TestConfig1, "must be a dataclass"),
        (TestConfig2, "must have a default"),
        (TestConfig3, "must have a docstring"),
        (TestConfig4, "must use a single Literal"),
    ],
)
def test_config(test_config, expected_error):
    """Applying ``config`` to each class must raise the matching error."""
    with pytest.raises(Exception, match=expected_error):
        config(test_config)
def test_get_field():
@dataclass
......@@ -152,7 +186,7 @@ def test_get_pooling_config():
revision=None,
)
pooling_config = model_config._init_pooler_config(None)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert pooling_config.normalize
......@@ -172,11 +206,12 @@ def test_get_pooling_config_from_args():
dtype="float16",
revision=None)
override_config = PoolerConfig(pooling_type='CLS', normalize=True)
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
model_config.override_pooler_config = override_pooler_config
pooling_config = model_config._init_pooler_config(override_config)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert asdict(pooling_config) == asdict(override_config)
assert asdict(pooling_config) == asdict(override_pooler_config)
@pytest.mark.skipif(current_platform.is_rocm(),
......@@ -376,3 +411,16 @@ def test_generation_config_loading():
override_generation_config=override_generation_config)
assert model_config.get_diff_sampling_param() == override_generation_config
@pytest.mark.parametrize(
    "pt_load_map_location",
    [
        "cuda",
        {
            "": "cuda"
        },
    ],
)
def test_load_config_pt_load_map_location(pt_load_map_location):
    """``pt_load_map_location`` is unchanged after building a ``VllmConfig``."""
    vllm_config = VllmConfig(
        load_config=LoadConfig(pt_load_map_location=pt_load_map_location))

    stored_location = vllm_config.load_config.pt_load_map_location
    assert stored_location == pt_load_map_location
......@@ -11,7 +11,7 @@ from vllm.scalar_type import scalar_types
(0, 15, scalar_types.uint4),
(-8, 7, scalar_types.uint4b8),
(-128, 127, scalar_types.uint8b128),
(-6., 6., scalar_types.float4_e2m1fn),
(-6., 6., scalar_types.float4_e2m1f),
(-28., 28., scalar_types.float6_e3m2f),
(torch.int8, scalar_types.int8),
(torch.uint8, scalar_types.uint8),
......
......@@ -10,7 +10,7 @@ import torch
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from vllm.model_executor.model_loader.loader import ShardedStateLoader
from vllm.model_executor.model_loader import ShardedStateLoader
prompts = [
"Hello, my name is",
......
# SPDX-License-Identifier: Apache-2.0
import sys
import types
from unittest import mock
from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
TritonPlaceholder)
def test_triton_placeholder_is_module():
    """``TritonPlaceholder`` is a module object named ``triton``."""
    placeholder = TritonPlaceholder()

    assert isinstance(placeholder, types.ModuleType)
    assert placeholder.__name__ == "triton"
def test_triton_language_placeholder_is_module():
    """``TritonLanguagePlaceholder`` is a module named ``triton.language``."""
    placeholder = TritonLanguagePlaceholder()

    assert isinstance(placeholder, types.ModuleType)
    assert placeholder.__name__ == "triton.language"
def test_triton_placeholder_decorators():
    """Bare placeholder decorators leave the decorated function callable."""
    triton = TritonPlaceholder()

    @triton.jit
    def jitted(value):
        return value

    @triton.autotune
    def autotuned(value):
        return value

    @triton.heuristics
    def with_heuristics(value):
        return value

    assert jitted(1) == 1
    assert autotuned(2) == 2
    assert with_heuristics(3) == 3
def test_triton_placeholder_decorators_with_args():
    """Placeholder decorators called with arguments still return the function."""
    triton = TritonPlaceholder()

    @triton.jit(debug=True)
    def jitted(value):
        return value

    @triton.autotune(configs=[], key="x")
    def autotuned(value):
        return value

    @triton.heuristics(
        {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
    def with_heuristics(value):
        return value

    assert jitted(1) == 1
    assert autotuned(2) == 2
    assert with_heuristics(3) == 3
def test_triton_placeholder_language():
    """The language placeholder is a module whose stub attributes are ``None``."""
    tl_stub = TritonLanguagePlaceholder()

    assert isinstance(tl_stub, types.ModuleType)
    assert tl_stub.__name__ == "triton.language"

    assert tl_stub.constexpr is None
    assert tl_stub.dtype is None
    assert tl_stub.int64 is None
def test_triton_placeholder_language_from_parent():
    """``TritonPlaceholder().language`` is a ``TritonLanguagePlaceholder``."""
    tl_stub = TritonPlaceholder().language
    assert isinstance(tl_stub, TritonLanguagePlaceholder)
def test_no_triton_fallback():
    """With ``triton`` unimportable, ``vllm.triton_utils`` exposes placeholders."""
    # Drop any cached triton / triton_utils modules so the import below
    # is executed from scratch.
    for cached_name in (
            "triton",
            "triton.language",
            "vllm.triton_utils",
            "vllm.triton_utils.importing",
    ):
        sys.modules.pop(cached_name, None)

    # A ``None`` entry in sys.modules makes ``import triton`` raise
    # ImportError, i.e. it mimics triton not being installed.
    with mock.patch.dict(sys.modules, {"triton": None}):
        from vllm.triton_utils import HAS_TRITON, tl, triton

        assert HAS_TRITON is False
        assert triton.__class__.__name__ == "TritonPlaceholder"
        assert (triton.language.__class__.__name__ ==
                "TritonLanguagePlaceholder")
        assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
......@@ -3,6 +3,7 @@
import asyncio
import hashlib
import json
import pickle
import socket
from collections.abc import AsyncIterator
......@@ -10,13 +11,15 @@ from unittest.mock import patch
import pytest
import torch
import zmq
from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, sha256,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)
from .utils import create_new_process_for_each_test, error_on_warning
......@@ -136,6 +139,7 @@ def parser():
parser.add_argument('--model-name')
parser.add_argument('--batch-size', type=int)
parser.add_argument('--enable-feature', action='store_true')
parser.add_argument('--hf-overrides', type=json.loads)
return parser
......@@ -249,6 +253,29 @@ def test_no_model_tag(parser_with_config, cli_config_file):
parser_with_config.parse_args(['serve', '--config', cli_config_file])
def test_dict_args(parser):
    """Dotted ``--hf-overrides.<key>`` flags parse into one nested dict."""
    cli_args = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        "--hf-overrides.key2.key3",
        "val2",
        "--hf-overrides.key2.key4",
        "val3",
        "--hf-overrides.key5=val4",
    ]
    expected_overrides = {
        "key1": "val1",
        "key2": {
            "key3": "val2",
            "key4": "val3",
        },
        "key5": "val4",
    }

    namespace = parser.parse_args(cli_args)

    assert namespace.model_name == "something.something"
    assert namespace.hf_overrides == expected_overrides
# yapf: enable
@pytest.mark.parametrize(
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
......@@ -662,3 +689,58 @@ def test_sha256(input: tuple, output: int):
# hashing different input, returns different value
assert hash != sha256(input + (1, ))
@pytest.mark.parametrize(
    "path,expected",
    [
        ("ipc://some_path", ("ipc", "some_path", "")),
        ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
        # IPv6 address: the brackets are not part of the returned host.
        ("tcp://[::1]:5555", ("tcp", "::1", "5555")),
        ("inproc://some_identifier", ("inproc", "some_identifier", "")),
    ],
)
def test_split_zmq_path(path, expected):
    """``split_zmq_path`` returns the expected (scheme, host, port) triple."""
    parts = split_zmq_path(path)
    assert parts == expected
@pytest.mark.parametrize(
    "invalid_path",
    [
        # Missing scheme
        "invalid_path",
        # Missing port
        "tcp://127.0.0.1",
        # Missing port for IPv6
        "tcp://[::1]",
        # Missing host
        "tcp://:5555",
    ],
)
def test_split_zmq_path_invalid(invalid_path):
    """Malformed paths make ``split_zmq_path`` raise ``ValueError``."""
    with pytest.raises(ValueError):
        split_zmq_path(invalid_path)
def test_make_zmq_socket_ipv6():
    """``make_zmq_socket`` enables ``zmq.IPV6`` for an IPv6 TCP path.

    The test is skipped when the host cannot create an ``AF_INET6`` socket.
    The ZMQ context is always destroyed, even when socket creation or the
    assertion fails, so no socket or context leaks into later tests.
    """
    # Check if IPv6 is supported by trying to create an IPv6 socket
    try:
        sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        sock.close()
    except OSError:  # socket.error is a legacy alias of OSError
        pytest.skip("IPv6 is not supported on this system")

    ctx = zmq.Context()
    ipv6_path = "tcp://[::]:5555"  # IPv6 unspecified (wildcard) address
    socket_type = zmq.REP  # Example socket type

    try:
        # Create the socket
        zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)

        # Verify that the IPV6 option is set
        assert zsock.getsockopt(
            zmq.IPV6
        ) == 1, "IPV6 option should be enabled for IPv6 addresses"
    finally:
        # destroy() closes every socket created from this context before
        # terminating it, so cleanup cannot block on a socket that was left
        # open by a failure above (a bare term() would wait for it).
        ctx.destroy(linger=0)
def test_make_zmq_path():
    """IPv4 hosts are used verbatim; IPv6 hosts are wrapped in brackets."""
    ipv4_path = make_zmq_path("tcp", "127.0.0.1", "5555")
    assert ipv4_path == "tcp://127.0.0.1:5555"

    ipv6_path = make_zmq_path("tcp", "::1", "5555")
    assert ipv6_path == "tcp://[::1]:5555"
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import patch
import pytest
from vllm.envs import get_vllm_port
def test_get_vllm_port_not_set():
    """With an empty environment, ``get_vllm_port`` returns ``None``."""
    with patch.dict(os.environ, {}, clear=True):
        port = get_vllm_port()
        assert port is None
def test_get_vllm_port_valid():
    """An integer string in ``VLLM_PORT`` is returned as an ``int``."""
    with patch.dict(os.environ, {"VLLM_PORT": "5678"}, clear=True):
        port = get_vllm_port()
        assert port == 5678
def test_get_vllm_port_invalid():
    """A non-integer ``VLLM_PORT`` makes ``get_vllm_port`` raise ``ValueError``."""
    with patch.dict(os.environ, {"VLLM_PORT": "abc"}, clear=True):
        with pytest.raises(ValueError, match="must be a valid integer"):
            get_vllm_port()
def test_get_vllm_port_uri():
    """A URI in ``VLLM_PORT`` is rejected with a URI-specific ``ValueError``."""
    uri_env = {"VLLM_PORT": "tcp://localhost:5678"}
    with patch.dict(os.environ, uri_env, clear=True):
        with pytest.raises(ValueError, match="appears to be a URI"):
            get_vllm_port()
......@@ -60,8 +60,16 @@ def _run_incremental_decode(tokenizer,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
request = EngineCoreRequest("", prompt_token_ids, None, None, None, params,
None, 0.0, None)
request = EngineCoreRequest("",
prompt_token_ids,
None,
None,
None,
params,
None,
0.0,
None,
cache_salt=None)
if fast is None:
detokenizer = IncrementalDetokenizer.from_new_request(
......
......@@ -2,7 +2,7 @@
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
:meth:`vllm.LLMEngine._get_eos_token_id`.
{meth}`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import vllm
from vllm.lora.request import LoRARequest
# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
# for this. The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.
# These adapters are trained using a standard huggingface peft training script,
# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
# 100 training iterations with a training batch size of 100.
@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
    """
    Run every test in this file with VLLM_USE_V1=1, because Multi-LoRA is
    only supported on the v1 TPU backend.
    """
    with monkeypatch.context() as env_patch:
        env_patch.setenv("VLLM_USE_V1", "1")
        yield
def setup_vllm(num_loras: int) -> vllm.LLM:
    """Create a LoRA-enabled Qwen2.5-3B-Instruct engine.

    ``num_loras`` is passed through as ``max_loras``.
    """
    engine_kwargs = dict(
        model="Qwen/Qwen2.5-3B-Instruct",
        num_scheduler_steps=1,
        max_model_len=256,
        max_seq_len_to_capture=256,
        max_num_seqs=8,
        enable_lora=True,
        max_loras=num_loras,
        max_lora_rank=8,
    )
    return vllm.LLM(**engine_kwargs)
def test_single_lora():
    """
    Run a single LoRA adapter on the TPU backend.

    The adapter "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter"
    forces Qwen2.5-3B-Instruct to claim 1+1=1, so the first character of the
    stripped completion must be the digit 1.
    """
    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"
    adapter = LoRARequest(
        "lora_adapter_1", 1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")

    results = llm.generate(
        prompt,
        sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
        lora_request=adapter,
    )
    output = results[0].outputs[0].text

    answer = output.strip()[0]
    assert answer.isdigit()
    assert int(answer) == 1
def test_lora_hotswapping():
    """
    Run several LoRA adapters one after another on the TPU backend while the
    engine only has room for one (``max_loras=1``).

    Adapter "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter"
    forces Qwen2.5-3B-Instruct to claim 1+1=x; x ranges over 1..4.
    """
    adapter_repo = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{x}", x, adapter_repo.format(x))
        for x in range(1, 5)
    ]

    llm = setup_vllm(1)
    prompt = "What is 1+1? \n"

    # Requests are ordered by x, so the n-th request must answer n.
    for expected, req in enumerate(lora_requests, start=1):
        results = llm.generate(
            prompt,
            sampling_params=vllm.SamplingParams(max_tokens=256,
                                                temperature=0),
            lora_request=req,
        )
        output = results[0].outputs[0].text

        answer = output.strip()[0]
        assert answer.isdigit()
        assert int(answer) == expected
def test_multi_lora():
    """
    Run several LoRA adapters on the TPU backend when the engine has room
    for all of them (``max_loras=4``).

    Adapter "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter"
    forces Qwen2.5-3B-Instruct to claim 1+1=x; x ranges over 1..4.
    """
    adapter_repo = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{x}", x, adapter_repo.format(x))
        for x in range(1, 5)
    ]

    llm = setup_vllm(4)
    prompt = "What is 1+1? \n"

    # Requests are ordered by x, so the n-th request must answer n.
    for expected, req in enumerate(lora_requests, start=1):
        results = llm.generate(
            prompt,
            sampling_params=vllm.SamplingParams(max_tokens=256,
                                                temperature=0),
            lora_request=req,
        )
        output = results[0].outputs[0].text

        answer = output.strip()[0]
        assert answer.isdigit()
        assert int(answer) == expected
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import
N_TOKENS = [16, 1024, 4096]
HIDDEN_SIZES = [1024, 2048, 4096]
DTYPES = [torch.bfloat16]
NUM_LORA = [1, 4, 16]
RANKS = [32, 256, 512]
def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
    """
    Build seeded random BGMV inputs on the "xla" device plus the reference
    result.

    Inputs: (All integers)
        T: Total number of tokens
        D: Input dim
        L: LoRA Dim
        N: N LoRAs

    Outputs:
        inputs:     torch.Tensor - shape (T, D)
        loras:      torch.Tensor - shape (N, 1, L, D)
        idxs:       torch.Tensor - shape (T, ) - values drawn from [0, N)
        ref_output: torch.Tensor - shape (T, L) - ref_bgmv(inputs, loras, idxs)
    """
    torch.manual_seed(seed)

    # The draw order (inputs, loras, idxs) determines the generated values.
    device = "xla"
    inputs = torch.randn((T, D), device=device, dtype=dtype)
    loras = torch.randn((N, 1, L, D), device=device, dtype=dtype)
    idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device=device)

    return inputs, loras, idxs, ref_bgmv(inputs, loras, idxs)
def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
    """Reference batched gather matrix-vector product.

    Row ``i`` of the result is ``loras[idxs[i]] @ inputs[i]``. A 4-D
    selection has its dimension 1 squeezed out before the product.
    """
    picked = loras[idxs]
    if picked.dim() == 4:
        picked = picked.squeeze(1)

    num_rows, out_features, in_features = picked.shape
    # Treat every input row as a column vector for the batched matmul.
    column_vectors = inputs.reshape((num_rows, in_features, 1))
    product = torch.matmul(picked, column_vectors)
    return product.reshape((num_rows, out_features))
# Parameterize tests with various shapes and dtypes
@pytest.mark.parametrize("T", N_TOKENS)
@pytest.mark.parametrize("D", HIDDEN_SIZES)
@pytest.mark.parametrize("L", RANKS)
@pytest.mark.parametrize("N", NUM_LORA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", [0])
def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
    """Compare ``torch.ops.xla.bgmv`` against the ``ref_bgmv`` reference.

    For "expand" the input and LoRA dimensions trade places.
    """
    in_dim, lora_dim = (L, D) if op_type == "expand" else (D, L)

    inputs, loras, idxs, ref_output = generate_test_data(
        T, in_dim, lora_dim, N, seed, dtype)

    # Run bgmv
    output = torch.ops.xla.bgmv(inputs, loras, idxs)

    # The kernel must not produce NaNs ...
    assert not torch.isnan(output).any()

    # ... and must agree with the reference within tolerance.
    assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)
# SPDX-License-Identifier: Apache-2.0
"""Tests for the Pallas MOE implementation.
Run `pytest tests/kernels/moe/test_moe_pallas.py`.
"""
import pytest
import torch
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.fused_moe.moe_pallas import (
fused_moe as pallas_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as torch_moe)
# yapf: enable
from vllm.platforms import current_platform
if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True)
NUM_EXPERTS = [8, 64]
EP_SIZE = [1]
TOP_KS = [2, 6]
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_pallas_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
):
    """Check the Pallas fused MoE against the iterative torch implementation.

    Both implementations receive the same random activations, expert
    weights and gating scores created on the XLA device; their outputs must
    agree within an absolute tolerance of 2e-2.
    """
    import torch_xla.core.xla_model as xm

    with torch.device(xm.xla_device()):
        a = torch.randn((m, k), dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
        w2 = torch.randn((e, k, n), dtype=dtype) / 10
        score = torch.randn((m, e), dtype=dtype)

        # TODO: Support ep
        if ep_size > 1:
            pytest.skip("No support for ep_size > 1 yet")
        e_map = None

        # Identical keyword arguments for both implementations.
        moe_kwargs = dict(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )
        torch_output = torch_moe(**moe_kwargs)
        pallas_output = pallas_moe(**moe_kwargs)
        xm.mark_step()

        # Compare outputs
        torch.testing.assert_close(
            pallas_output.cpu(),
            torch_output.cpu(),
            atol=2e-2,
            rtol=0,
        )
......@@ -29,7 +29,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.model_executor.model_loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
......@@ -952,7 +952,7 @@ def get_client_text_logprob_generations(
completions: list[Completion]) -> list[TextTextLogprobs]:
'''Operates on the output of a request made to an Open-AI-protocol
completions endpoint; obtains top-rank logprobs for each token in
each :class:`SequenceGroup`
each {class}`SequenceGroup`
'''
text_generations = get_client_text_generations(completions)
text = ''.join(text_generations)
......
# SPDX-License-Identifier: Apache-2.0
import importlib
import pytest
import torch
......@@ -10,8 +11,7 @@ from vllm.utils import GiB_bytes, sha256
from vllm.v1.core.kv_cache_manager import KVCacheManager
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
FreeKVCacheBlockQueue, KVCacheBlock,
from vllm.v1.core.kv_cache_utils import (FreeKVCacheBlockQueue, KVCacheBlock,
PrefixCachingMetrics,
estimate_max_model_len,
generate_block_hash_extra_keys,
......@@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
hash_request_tokens,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
......@@ -29,7 +30,8 @@ from vllm.v1.request import Request
def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
mm_hashes=None,
cache_salt=None):
if mm_positions is None:
multi_modal_inputs = None
else:
......@@ -45,6 +47,7 @@ def make_request(request_id,
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)
......@@ -52,21 +55,39 @@ def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False):
use_mla=False,
sliding_window=None):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla)
use_mla=use_mla,
sliding_window=sliding_window)
def test_none_hash():
assert NONE_HASH is not None
assert isinstance(NONE_HASH, int)
assert NONE_HASH != 0
def test_none_hash(monkeypatch):
    """Check ``NONE_HASH`` after reloading the module with and without
    ``PYTHONHASHSEED`` set in the environment."""
    import vllm.v1.core.kv_cache_utils

    # case 1: PYTHONHASHSEED is not set, use random
    with monkeypatch.context() as env:
        env.delenv('PYTHONHASHSEED', raising=False)
        reloaded = importlib.reload(vllm.v1.core.kv_cache_utils)

        none_hash = reloaded.NONE_HASH
        assert none_hash is not None
        assert isinstance(none_hash, int)
        assert none_hash != 0

    # case 2: PYTHONHASHSEED is set, use the seed
    with monkeypatch.context() as env:
        env.setenv('PYTHONHASHSEED', 'python hash seed')
        reloaded = importlib.reload(vllm.v1.core.kv_cache_utils)

        none_hash = reloaded.NONE_HASH
        assert none_hash is not None
        assert isinstance(none_hash, int)
        assert sha256('python hash seed') == none_hash
def test_kv_cache_block():
import vllm.v1.core.kv_cache_utils
# Test KVCacheBlock initialization
block = KVCacheBlock(block_id=0)
assert block.block_id == 0
......@@ -80,7 +101,8 @@ def test_kv_cache_block():
assert block.ref_cnt == 0
# Test block hash setting and resetting
block_hash = BlockHashType(hash_value=123, token_ids=(1, 2, 3))
block_hash = vllm.v1.core.kv_cache_utils.BlockHashType(hash_value=123,
token_ids=(1, 2, 3))
block.block_hash = block_hash
assert block.block_hash == block_hash
......@@ -213,15 +235,55 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
assert next_mm_idx == 0
def test_generate_block_hash_extra_keys_cache_salt():
    """The cache salt is an extra key only for ranges starting at token 0."""
    salted_request = make_request(
        request_id=0,
        prompt_token_ids=list(range(6)),
        mm_positions=None,
        mm_hashes=None,
        cache_salt="salt",
    )

    # Ranges that start at token 0 carry the salt ...
    for start, end in ((0, 1), (0, 10)):
        extra_keys, _ = generate_block_hash_extra_keys(
            salted_request, start, end, 0)
        assert extra_keys == ('salt', )

    # ... ranges that start later do not.
    for start, end in ((1, 2), (6, 10)):
        extra_keys, _ = generate_block_hash_extra_keys(
            salted_request, start, end, 0)
        assert extra_keys is None

    # The salt is combined with the multimodal hash when both apply.
    request_mm = make_request(
        request_id=0,
        prompt_token_ids=list(range(20)),
        mm_positions=[
            PlaceholderRange(offset=0, length=5),
        ],
        mm_hashes=["hash1"],
        cache_salt="salt",
    )

    extra_keys, next_mm_idx = generate_block_hash_extra_keys(
        request_mm, 0, 5, 0)
    assert extra_keys == ("hash1", "salt")
    assert next_mm_idx == 1
@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_hash_block_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils
parent_block_hash = 123
curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2")
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
curr_block_token_ids, extra_keys)
assert isinstance(block_hash, BlockHashType)
assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHashType)
assert block_hash.hash_value == hash_fn(
(parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash.token_ids == curr_block_token_ids
......@@ -230,6 +292,7 @@ def test_hash_block_tokens(hash_fn):
@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_hash_request_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
......@@ -244,8 +307,10 @@ def test_hash_request_tokens(hash_fn):
block_hashes = hash_request_tokens(hash_fn, block_size, request)
assert len(block_hashes) == 2
assert isinstance(block_hashes[0], BlockHashType)
assert isinstance(block_hashes[1], BlockHashType)
assert isinstance(block_hashes[0],
vllm.v1.core.kv_cache_utils.BlockHashType)
assert isinstance(block_hashes[1],
vllm.v1.core.kv_cache_utils.BlockHashType)
# Check the first block
assert block_hashes[0].token_ids == (0, 1, 2)
......@@ -430,6 +495,68 @@ def test_unify_kv_cache_configs():
unify_kv_cache_configs(diff_kv_cache_config)
def test_merge_kv_cache_spec():
    """Exercise ``merge`` on matching and mismatching KV cache specs."""
    # Two identical specs merge into a spec with the same fields.
    identical_specs = [new_kv_cache_spec(num_kv_heads=32) for _ in range(2)]
    merged = identical_specs[0].merge(identical_specs)
    assert merged.block_size == 16
    assert merged.num_kv_heads == 32
    assert merged.head_size == 64
    assert merged.dtype == torch.float32
    assert merged.sliding_window is None

    # A differing number of KV heads is rejected.
    mismatched_heads = [
        new_kv_cache_spec(num_kv_heads=32),
        new_kv_cache_spec(num_kv_heads=16),
    ]
    with pytest.raises(AssertionError):
        mismatched_heads[0].merge(mismatched_heads)

    # Mixing spec classes is rejected, whichever element merge is called on.
    full_spec = new_kv_cache_spec(num_kv_heads=32)
    sliding_spec = SlidingWindowSpec(
        block_size=full_spec.block_size,
        num_kv_heads=full_spec.num_kv_heads,
        head_size=full_spec.head_size,
        dtype=full_spec.dtype,
        use_mla=full_spec.use_mla,
        sliding_window=1,
    )
    mixed_types = [full_spec, sliding_spec]
    for position in (0, 1):
        with pytest.raises(AssertionError):
            mixed_types[position].merge(mixed_types)

    # Two different sliding-window sizes cannot be merged.
    conflicting_windows = [
        new_kv_cache_spec(num_kv_heads=32),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=2),
    ]
    with pytest.raises(ValueError):
        conflicting_windows[0].merge(conflicting_windows)

    # The same sliding-window size on every spec is kept.
    matching_windows = [
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
    ]
    merged = matching_windows[0].merge(matching_windows)
    assert merged.sliding_window == 1

    # A window of 1 merged with ``None`` yields 1.
    window_and_none = [
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=None),
    ]
    merged = window_and_none[0].merge(window_and_none)
    assert merged.sliding_window == 1
@pytest.mark.parametrize(
("model_id", "max_model_len", "want_estimated_max_len"), [
("Qwen/Qwen1.5-7B", 16385, 16384),
......@@ -498,10 +625,10 @@ def test_allocate_with_lookahead():
max_model_len=100)
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_new_tokens=3,
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
)
assert len(blocks) == 2 # ceil(5/4)=2 blocks
assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config,
......@@ -509,10 +636,10 @@ def test_allocate_with_lookahead():
# required_blocks = ceil((3 + 2) /4) = 2
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_new_tokens=3,
num_lookahead_tokens=2,
)
assert len(blocks) == 2
assert len(blocks.blocks) == 2
# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2
......@@ -520,7 +647,7 @@ def test_allocate_with_lookahead():
max_model_len=100)
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_new_tokens=3,
num_lookahead_tokens=4,
)
assert len(blocks) == 2
assert len(blocks.blocks) == 2
......@@ -6,6 +6,7 @@ from typing import Optional
import pytest
import torch
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
......@@ -14,14 +15,15 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_block_tokens)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
KVCacheGroupSpec, SlidingWindowSpec)
def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None,
prompt_logprobs: Optional[int] = None):
prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None):
if mm_positions is None:
multi_modal_inputs = None
else:
......@@ -38,6 +40,7 @@ def make_request(request_id,
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)
......@@ -46,9 +49,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
num_blocks=num_blocks,
tensors={},
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
)
],
)
......@@ -75,10 +79,12 @@ def test_prefill(hash_algo):
req0 = make_request("0", all_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
# Check full block metadata
parent_block_hash = None
......@@ -101,12 +107,14 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5]
for block in computed_blocks:
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
# At this point, we should have 5 free blocks left.
......@@ -133,11 +141,13 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [6]
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[6]]
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
......@@ -155,11 +165,13 @@ def test_prefill(hash_algo):
# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 10))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
blocks = manager.allocate_slots(req3, 16 * 10,
len(computed_blocks.blocks) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
......@@ -190,12 +202,14 @@ def test_prefill_plp():
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids, prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks
assert len(manager.req_to_block_hashes[req0.request_id]) == 0
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks]
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0_block_hashes = [b.block_hash for b in blocks.blocks]
# Check full block metadata
parent_block_hash = None
......@@ -219,12 +233,14 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5]
for block in computed_blocks:
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
# At this point, we should have 5 free blocks left.
......@@ -252,18 +268,20 @@ def test_prefill_plp():
common_token_ids + unique_token_ids,
prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert not computed_blocks
assert len(manager.req_to_block_hashes[req2.request_id]) == 0
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 55, computed_blocks)
block_ids = [b.block_id for b in blocks]
blocks = manager.allocate_slots(req2, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks] == req0_block_hashes
assert block_ids != [1, 2, 3, 4]
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
assert block_ids != [[1, 2, 3, 4]]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
for block_id in block_ids:
for block_id in block_ids[0]:
assert manager.block_pool.blocks[block_id].ref_cnt == 1
manager.free(req2)
......@@ -284,18 +302,23 @@ def test_decode():
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
for _ in range(4):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
new_blocks = manager.allocate_slots(req0, 4,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-1].block_hash is None
# Append slots with allocating a new block.
req0.num_computed_tokens = 59
......@@ -303,10 +326,14 @@ def test_decode():
# the preallocated block.
for _ in range(9 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19)
assert new_blocks is not None and len(new_blocks) == 1
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
new_blocks = manager.allocate_slots(req0, 19,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 1
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-2].block_hash is not None
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-1].block_hash is None
def test_evict():
......@@ -319,19 +346,23 @@ def test_evict():
last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 6 # 5 full + 1 partial
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 6 # 5 full + 1 partial
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
assert len(blocks) == 3 # 3 full blocks
blocks = manager.allocate_slots(req1, 3 * 16,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 3 # 3 full blocks
last_token_id += 3 * 16
# 10 - (6 + 3) == 1
......@@ -348,10 +379,12 @@ def test_evict():
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [1, 2]
assert computed_blocks.get_block_ids() == [[1, 2]]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [10]
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[10]]
assert manager.block_pool.free_block_queue.num_free_blocks == 7
......@@ -371,10 +404,12 @@ def test_hash_block_correct_reuse():
num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
assert len(blocks) == 1
blocks = manager.allocate_slots(req, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
# Deallocate the block.
manager.free(req)
......@@ -383,12 +418,15 @@ def test_hash_block_correct_reuse():
# block is cleared.
req = make_request("1", list(range(num_tokens - 1)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
assert len(blocks) == 1
blocks = manager.allocate_slots(req, num_tokens - 1,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None
assert manager.block_pool.blocks[
blocks.blocks[0].block_id].block_hash is None
def test_computed_blocks_not_evicted():
......@@ -407,20 +445,24 @@ def test_computed_blocks_not_evicted():
num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
blocks = manager.allocate_slots(req0, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 1
# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 2
blocks = manager.allocate_slots(req1, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2
# Free the blocks.
manager.free(req0)
......@@ -430,14 +472,15 @@ def test_computed_blocks_not_evicted():
# cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1
assert computed_blocks[0].block_id == 1
assert len(computed_blocks.blocks) == 1
assert computed_blocks.blocks[0].block_id == 1
assert num_computed_tokens == block_size
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 2
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2
def test_basic_prefix_caching_disabled():
......@@ -454,10 +497,12 @@ def test_basic_prefix_caching_disabled():
req1 = make_request("1", list(range(10))) # 2 blocks and some more
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10, computed_blocks)
assert len(blocks) == 3
blocks = manager.allocate_slots(req1, 10,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 3
# Free the blocks.
manager.free(req1)
......@@ -465,17 +510,21 @@ def test_basic_prefix_caching_disabled():
# No caching.
req2 = make_request("2", list(range(16))) # shared prefix
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16, computed_blocks)
assert len(blocks) == 4
blocks = manager.allocate_slots(req2, 16,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 4
# New requests should not have any blocks.
req3 = make_request("3", list(range(4)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4, computed_blocks)
blocks = manager.allocate_slots(req3, 4,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert not blocks
......@@ -565,7 +614,7 @@ def test_mm_prefix_caching():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id]
assert len(block_hashes) == 3
......@@ -573,15 +622,19 @@ def test_mm_prefix_caching():
assert block_hashes[1].extra_keys == ("aaa", "bbb")
assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0
new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
# The just completed block should have hashes with extra keys.
assert len(block_hashes) == 4
......@@ -599,10 +652,74 @@ def test_mm_prefix_caching():
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3
assert len(computed_blocks.blocks) == 3
assert num_computed_tokens == 3 * 16
def test_cache_key_salting():
    """Check that a request's cache salt is applied when hashing blocks.

    Requests with the same token ids and the same salt must share cached
    blocks, while the same token ids with a different salt must miss.
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
    )

    # 3 complete blocks and an incomplete block with 11 tokens.
    common_token_ids = [i for i in range(3) for _ in range(block_size)]
    token_ids = common_token_ids + [3] * 11
    req0 = make_request("0", token_ids, cache_salt="salt1")
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

    # Nothing is cached yet; only the first block hash carries the salt
    # as an extra key.
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    block_hashes = manager.req_to_block_hashes[req0.request_id]
    assert len(block_hashes) == 3
    assert block_hashes[0].extra_keys == ("salt1", )
    assert block_hashes[1].extra_keys is None
    assert block_hashes[2].extra_keys is None

    # 59 == len(token_ids): 3 full blocks plus the 11-token partial block.
    # Use block_size (not a literal 16) so the test follows the config above.
    blocks = manager.allocate_slots(req0, 59,
                                    len(computed_blocks.blocks) * block_size,
                                    computed_blocks)
    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
    req0.num_computed_tokens = 59

    # Append slots without allocating a new block.
    for _ in range(5):
        req0.append_output_token_ids(8)
    new_blocks = manager.allocate_slots(
        req0, 5,
        len(computed_blocks.blocks) * block_size, computed_blocks)
    assert new_blocks is not None and len(new_blocks.blocks) == 0

    # Now one more block that should not have extra keys.
    assert len(block_hashes) == 4
    assert block_hashes[3].extra_keys is None

    # Test cache hit with a new request that has the same salt.
    token_ids = common_token_ids + [4] * 11
    req1 = make_request("1", token_ids, cache_salt="salt1")
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    # Should match only a prefix of 3 blocks.
    assert len(computed_blocks.blocks) == 3
    assert num_computed_tokens == 3 * block_size

    # Test cache miss with same content but different salt.
    token_ids = common_token_ids + [4] * 11
    req2 = make_request("2", token_ids, cache_salt="salt2")
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(computed_blocks.blocks) == 0
    assert num_computed_tokens == 0
    block_hashes = manager.req_to_block_hashes[req2.request_id]
    assert len(block_hashes) == 3
    assert block_hashes[0].extra_keys == ("salt2", )
def test_prefill_not_enough_free_blocks_with_computed_blocks():
"""
This is a unit test that tests the correctness of the allocate_slots
......@@ -621,18 +738,20 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks)
block_part0 = manager.req_to_blocks[req0.request_id]
manager.allocate_slots(req0, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks == block_part0
assert computed_blocks.blocks == block_part0
assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks)
block_part1 = manager.req_to_blocks[req1.request_id]
manager.allocate_slots(req1, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| ... |
manager.free(req1)
......@@ -643,9 +762,10 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2, computed_blocks)
manager.allocate_slots(req2, block_size * 2,
len(computed_blocks.blocks) * 16, computed_blocks)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
......@@ -653,10 +773,12 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert manager.block_pool.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks == block_part1
assert computed_blocks.blocks == block_part1
assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None
assert manager.allocate_slots(req3, 48,
len(computed_blocks.blocks) * 16,
computed_blocks) is None
# Block 0-2 are used by Req 1.
assert {block.ref_cnt for block in block_part1[:3]} == {1}
# Block 3-5 are free.
......@@ -675,16 +797,18 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
req1 = make_request("1", all_token_ids)
computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks)
assert [b.block_id for b in blocks] == [5]
assert len(computed_blocks.blocks) == 3
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()
......@@ -712,15 +836,70 @@ def test_prefix_cache_stats_disabled():
# Call all functions that check whether log_stats is disabled.
req = make_request("0", list(range(16)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req, 16, computed_blocks)
manager.allocate_slots(req, 16,
len(computed_blocks.blocks) * 16, computed_blocks)
manager.reset_prefix_cache()
# Ensure prefix_cache_stats remains None
assert manager.prefix_cache_stats is None
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events(blocks_to_cache: int):
    """Walk the KV cache event queue through store, remove and clear.

    The pool is sized to ``blocks_to_cache + 1`` blocks and each request
    fills exactly ``blocks_to_cache`` full blocks.
    """
    block_size = 16
    num_tokens = block_size * blocks_to_cache
    manager = KVCacheManager(
        make_kv_cache_config(block_size, blocks_to_cache + 1),
        max_model_len=8192,
        enable_caching=True,
        enable_kv_cache_events=True,
    )

    # Phase 1: allocate blocks.
    # Expect one block-stored event holding blocks_to_cache hashes, and
    # take_events() to drain the kv_event_queue.
    first_req = make_request("0", list(range(num_tokens)))
    manager.allocate_slots(first_req, num_tokens)
    stored_event = manager.take_events()[-1]
    assert len(stored_event.block_hashes) == blocks_to_cache
    assert blocks_to_cache == len(
        manager.block_pool.cached_block_hash_to_block)
    assert (len(stored_event.token_ids) == stored_event.block_size *
            len(stored_event.block_hashes))
    assert len(manager.block_pool.kv_event_queue) == 0
    stored_block_hash = stored_event.block_hashes

    # Phase 2: free the blocks and send an identical request.
    # Expect blocks_to_cache removal events followed by a new stored event.
    manager.free(first_req)
    second_req = make_request("1", list(range(num_tokens)))
    manager.allocate_slots(second_req, num_tokens)
    events = manager.take_events()
    assert all(event.block_hashes[0] in stored_block_hash
               for event in events[:-1])
    assert len(events) == blocks_to_cache + 1
    assert isinstance(events[-2], BlockRemoved)
    assert len(events[-1].block_hashes) == blocks_to_cache
    assert blocks_to_cache == len(
        manager.block_pool.cached_block_hash_to_block)

    # Phase 3: clear everything.
    # Expect the final event to be a single all-blocks-cleared event.
    manager.free(second_req)
    manager.reset_prefix_cache()
    events = manager.take_events()
    assert isinstance(events[-1], AllBlocksCleared)
    assert len(manager.block_pool.cached_block_hash_to_block) == 0
def test_eagle_enabled_removes_last_block():
"""Verify Eagle does NOT remove blocks when request
length is divisible by block size."""
......@@ -738,18 +917,19 @@ def test_eagle_enabled_removes_last_block():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
manager.free(req)
# New request with same tokens + Eagle enabled
req_eagle = make_request("eagle_divisible", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Should retain 2 blocks:
# Should retain 1 block:
# 1. Original 3 blocks → pop last hash → 2 matched blocks
# 2. last_block_hash is not None → Eagle pop is not SKIPPED
assert len(computed_blocks) == 1
assert num_tokens == 1 * block_size # 32 tokens
# 2. drop last matched block → 1 remaining block
assert len(computed_blocks.blocks) == 1
assert num_tokens == 1 * block_size # 16 tokens
def test_eagle_with_partial_blocks():
......@@ -767,12 +947,70 @@ def test_eagle_with_partial_blocks():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
manager.free(req)
# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks) == 1
assert num_tokens == 1 * block_size
def test_eagle_with_sliding_window():
    """Test Eagle behavior with sliding window.

    Uses a single sliding-window KV cache group whose window equals one
    block, with ``use_eagle=True``. The test primes the cache, checks the
    prefix-cache hit for an identical request, then removes the first
    block's hash from the pool and checks that the lookup becomes a miss.
    """
    block_size = 16
    sliding_window_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
        use_mla=False,
    )
    manager = KVCacheManager(
        KVCacheConfig(
            num_blocks=10,
            tensors={},
            kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)],
        ),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # 2 full blocks + 5 tokens (non-divisible length)
    token_ids = [0] * (2 * block_size + 5)
    req = make_request("partial_block_test", token_ids)

    # Prime the cache. Use block_size rather than a literal 16 so the
    # token count stays in sync with the spec above.
    computed_blocks, _ = manager.get_computed_blocks(req)
    manager.allocate_slots(req, len(token_ids),
                           len(computed_blocks.blocks) * block_size,
                           computed_blocks)
    # Record the block hash of the first block in the request for later use.
    block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
    assert block_hash_first_block is not None
    manager.free(req)

    # New request with Eagle enabled
    req_eagle = make_request("partial_eagle", token_ids)
    computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
    # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
    assert len(computed_blocks.blocks) == 1
    assert num_tokens == 1 * block_size

    # Evict the first block in the request
    assert manager.block_pool.get_cached_block(
        block_hash_first_block) is not None
    manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block)

    # New request
    req_after_evict = make_request("partial_eagle_after_evict", token_ids)
    computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
    # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
    # not considered. But after dropping the last matched block due to eagle,
    # there will be no matched prefix.
    assert len(computed_blocks.blocks) == 0
    assert num_tokens == 0
......@@ -44,7 +44,7 @@ def create_scheduler(
(None)
Returns:
:class:`Scheduler` instance
{class}`Scheduler` instance
'''
if max_model_len is None:
max_model_len = max_num_batched_tokens
......@@ -812,10 +812,11 @@ def _assert_right_kv_cache_manager(
# Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
for req_id in req_ids:
blocks = scheduler.kv_cache_manager.req_to_blocks[req_id]
blocks = (scheduler.kv_cache_manager.single_type_manager.
req_to_blocks[req_id])
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
assert (scheduler.kv_cache_manager.num_cached_block[req_id] ==
EXPECTED_TOTAL_BLOCKS)
assert (scheduler.kv_cache_manager.single_type_manager.
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
......@@ -869,7 +870,7 @@ def test_kv_connector_basic():
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS)
NUM_MATCHED_NEW_TOKENS, False)
######################################################
# FIRST SET OF REQUESTS - External Hit Only
......@@ -980,7 +981,7 @@ def test_kv_connector_unable_to_allocate():
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS)
NUM_MATCHED_NEW_TOKENS, False)
# Create two requests. The second request will not be able to
# allocate slots because it will not have enough blocks.
......@@ -1059,7 +1060,7 @@ def test_kv_connector_handles_preemption():
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS)
NUM_MATCHED_NEW_TOKENS, False)
# Create two requests.
# Both can be scheduled at first, but the second request
......@@ -1195,9 +1196,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(scheduler.kv_cache_manager.req_to_blocks) == 0
assert len(
scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.num_cached_block) == 0
assert len(
scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment