Commit 5206e5e2 (unverified), authored by Harry Huang, committed by GitHub
Browse files

[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)


Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
parent fec9da0a
......@@ -24,7 +24,7 @@ pytestmark = pytest.mark.cpu_test
def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True):
    """Build a ``SlidingWindowManager`` for tests, bound to KV cache group 0.

    ``block_pool`` is forwarded by keyword only. Passing it both positionally
    and as ``block_pool=...`` in the same call raises ``TypeError`` (multiple
    values for one argument), so the positional form is not used.
    """
    return SlidingWindowManager(
        sliding_window_spec,
        block_pool=block_pool,
        enable_caching=enable_caching,
        kv_cache_group_id=0,
    )
......@@ -35,7 +35,7 @@ def get_chunked_local_attention_manager(
):
return ChunkedLocalAttentionManager(
chunked_local_attention_spec,
block_pool,
block_pool=block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
)
......@@ -342,11 +342,15 @@ def test_get_num_blocks_to_allocate():
]
assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
manager.get_num_blocks_to_allocate(
"1", 20 * block_size, cached_blocks_1, 0, 20 * block_size
)
== 20
)
assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
manager.get_num_blocks_to_allocate(
"2", 20 * block_size, cached_blocks_2, 0, 20 * block_size
)
== 15
)
......@@ -375,6 +379,7 @@ def test_evictable_cached_blocks_not_double_allocated():
num_tokens=2 * block_size,
new_computed_blocks=[evictable_block],
total_computed_tokens=block_size,
num_tokens_main_model=2 * block_size,
)
# Free capacity check should count evictable cached blocks, but allocation
# should only allocate the truly new block.
......@@ -386,7 +391,9 @@ def test_evictable_cached_blocks_not_double_allocated():
num_local_computed_tokens=block_size,
num_external_computed_tokens=0,
)
new_blocks = manager.allocate_new_blocks(request_id, num_tokens=4)
new_blocks = manager.allocate_new_blocks(
request_id, num_tokens=4, num_tokens_main_model=4
)
assert len(new_blocks) == 1
assert len(manager.req_to_blocks[request_id]) == 2
......@@ -411,10 +418,14 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
]
assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
manager.get_num_blocks_to_allocate(
"1", 20 * block_size, cached_blocks_1, 0, 20 * block_size
)
== 20
)
assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
manager.get_num_blocks_to_allocate(
"2", 20 * block_size, cached_blocks_2, 0, 20 * block_size
)
== 15
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing as mp
import os
import traceback
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import datasets
import pytest
import torch
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine.core_client import InprocClient
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import SamplerOutput
from vllm.v1.request import Request
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker import mamba_utils
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
from vllm.v1.worker.mamba_utils import get_mamba_groups
@dataclass
class StepAction:
    """Expected observations for one engine step of the prefix-cache test.

    The fake hooks in this file compare what the engine actually did in a step
    against the values stored here.
    """

    # Compared with ``input_batch.num_computed_tokens_cpu[0]`` after the real
    # ``execute_model`` has run.
    num_computed_tokens_start: int
    # Compared with the first entry of ``scheduler_output.num_scheduled_tokens``.
    num_scheduled_tokens: int
    # One flag per block of the request: 1 for a non-null block, 0 for a null
    # block. An empty list means "same as the previous step".
    kv_cache_block_ids: list[int]
    # (source, destination) positions in the mamba block table for the copy
    # expected during preprocessing; (-1, -1) means no copy is expected.
    preprocess_copy_idx: tuple[int, int]
    # Same convention as above, for the copy expected during postprocessing.
    postprocess_copy_idx: tuple[int, int]
# Number of draft tokens per step: passed to the engine's speculative config
# and used as the slice length by the fake draft proposer.
num_speculative_tokens = 3
# Upper bound on how many tokens the fake sampler accepts per step; the test
# overwrites it for every test case.
num_accepted_tokens = 1
# Token ids of the full tokenized article; assigned once a tokenizer exists.
prompt_token_ids: list[int] = []
MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
# Passed to LLM(block_size=...) and used to detect block-boundary crossings.
BLOCK_SIZE = 560
# Forwarded through hf_overrides as "num_hidden_layers".
NUM_HIDDEN_LAYERS = 1
# Cursor into step_actions, advanced by the fake get_output hook.
cur_step_action_idx = 0
# Expectation for the step currently in flight; None once step_actions is
# exhausted.
cur_step_action: StepAction | None = None
# Ordered per-step expectations of the test case that is currently running.
step_actions: list[StepAction] = []
def get_fake_sample_fn() -> Callable[..., SamplerOutput]:
    """Return a stand-in for ``GPUModelRunner._sample`` that replays the prompt.

    The returned function ignores the logits' values and instead emits token
    ids taken from the module-level ``prompt_token_ids``, so generation follows
    the reference article deterministically.

    Returns:
        A function with the ``_sample`` signature that produces a
        ``SamplerOutput`` for the first request in the batch.
    """

    def fake_sample_fn(
        self: GPUModelRunner,
        logits: torch.Tensor | None,
        spec_decode_metadata: SpecDecodeMetadata | None,
    ) -> SamplerOutput:
        assert logits is not None
        num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor
        num_computed_tokens = num_computed_tokens_cpu_tensor[0].item()
        num_prompt_tokens = self.input_batch.num_prompt_tokens[0].item()
        # While the prompt is still being consumed, the next token is the one
        # right after the prompt; afterwards it follows the computed tokens.
        if num_computed_tokens < num_prompt_tokens:
            first_token_id_index = num_prompt_tokens
        else:
            first_token_id_index = num_computed_tokens + 1

        if spec_decode_metadata is None:
            # No speculative decoding in this step: emit exactly one token.
            return SamplerOutput(
                sampled_token_ids=torch.tensor(
                    [[prompt_token_ids[first_token_id_index]]],
                    device="cuda",
                    dtype=torch.int32,
                ),
                logprobs_tensors=None,
            )

        num_sampled_tokens = spec_decode_metadata.cu_num_sampled_tokens[0].item() + 1
        accepted_tokens = prompt_token_ids[
            first_token_id_index : first_token_id_index
            + min(num_accepted_tokens, logits.shape[0])
        ]
        # Pad the remaining slots with -1 so the row has num_sampled_tokens
        # entries.
        sampled_token_ids = accepted_tokens + [-1] * (
            num_sampled_tokens - len(accepted_tokens)
        )
        return SamplerOutput(
            sampled_token_ids=torch.tensor(
                [sampled_token_ids], device="cuda", dtype=torch.int32
            ),
            logprobs_tensors=None,
        )

    return fake_sample_fn
def get_fake_propose_draft_token_ids_fn():
    """Return a stand-in for ``GPUModelRunner.propose_draft_token_ids``.

    The replacement proposes ``num_speculative_tokens`` draft tokens for the
    first request by slicing the module-level ``prompt_token_ids``; all other
    arguments of the original signature are accepted but unused.
    """

    def fake_propose_draft_token_ids_fn(
        self: GPUModelRunner,
        scheduler_output: SchedulerOutput,
        sampled_token_ids: torch.Tensor | list[list[int]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
        spec_decode_metadata: SpecDecodeMetadata | None,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[list[int]]:
        batch = self.input_batch
        num_computed_tokens = batch.num_computed_tokens_cpu_tensor[0].item()
        still_in_prompt = (
            batch.num_tokens_no_spec[0].item() <= batch.num_prompt_tokens[0].item()
        )
        if still_in_prompt:
            draft_start = batch.num_prompt_tokens[0].item()
        else:
            # The bonus token isn't considered as computed.
            draft_start = num_computed_tokens + 1
        # Skip past the tokens that were accepted in this step.
        draft_start += batch.num_accepted_tokens_cpu[0].item()
        draft_end = draft_start + num_speculative_tokens
        return [prompt_token_ids[draft_start:draft_end]]

    return fake_propose_draft_token_ids_fn
def get_fake_step_action_fn(original_step_action_fn: Callable):
    """Wrap ``InprocClient.get_output`` so each call selects the next StepAction.

    Before delegating to ``original_step_action_fn``, the wrapper publishes the
    next entry of ``step_actions`` in the global ``cur_step_action`` (or
    ``None`` when the list is exhausted) and advances the global cursor.
    """

    def fake_get_output(self: InprocClient):
        global cur_step_action_idx, cur_step_action
        exhausted = cur_step_action_idx >= len(step_actions)
        if exhausted:
            cur_step_action = None
        else:
            cur_step_action = step_actions[cur_step_action_idx]
            cur_step_action_idx += 1
        print(f"cur_step_action: {cur_step_action_idx=} {cur_step_action=}")
        return original_step_action_fn(self)

    return fake_get_output
def get_fake_allocate_slots_fn(original_allocate_slots_fn: Callable):
    """Wrap ``KVCacheManager.allocate_slots`` with a block-layout assertion.

    The wrapper forwards every argument to ``original_allocate_slots_fn`` and,
    when a ``cur_step_action`` is active, checks that the null / non-null
    pattern of the request's blocks in the first single-type manager equals
    ``cur_step_action.kv_cache_block_ids``.
    """

    def fake_allocate_slots_fn(
        self: KVCacheManager,
        request: Request,
        num_new_tokens: int,
        num_new_computed_tokens: int = 0,
        new_computed_blocks: KVCacheBlocks | None = None,
        num_lookahead_tokens: int = 0,
        num_external_computed_tokens: int = 0,
        delay_cache_blocks: bool = False,
        num_encoder_tokens: int = 0,
    ):
        result = original_allocate_slots_fn(
            self,
            request,
            num_new_tokens,
            num_new_computed_tokens,
            new_computed_blocks,
            num_lookahead_tokens,
            num_external_computed_tokens,
            delay_cache_blocks,
            num_encoder_tokens,
        )
        if cur_step_action is None:
            return result

        manager = self.coordinator.single_type_managers[0]
        allocated_blocks = manager.req_to_blocks[request.request_id]
        # 1 marks a real block, 0 marks a null block.
        observed_flags = [0 if block.is_null else 1 for block in allocated_blocks]
        assert observed_flags == cur_step_action.kv_cache_block_ids
        return result

    return fake_allocate_slots_fn
# Snapshots of mamba state captured by the fake execute_model hook, keyed by a
# token count that is a multiple of BLOCK_SIZE; each value is a pair of cloned
# tensors.
mamba_kv_cache_dict = {}
# Wrap GPUModelRunner.execute_model with per-step assertions and mamba state
# capture. Before delegating, the wrapper checks the scheduled token count
# against the active StepAction and may snapshot mamba state into
# mamba_kv_cache_dict; after delegating, it checks the computed-token count.
def get_fake_execute_model_fn(original_execute_model_fn: Callable):
# num_computed_tokens seen on the previous call of the wrapper (closure state).
last_num_computed_tokens = 0
def fake_execute_model_fn(
self: GPUModelRunner,
scheduler_output: SchedulerOutput,
intermediate_tensors: IntermediateTensors | None = None,
):
if cur_step_action is not None:
# Only the first entry of num_scheduled_tokens is inspected.
num_scheduled_tokens = next(
iter(scheduler_output.num_scheduled_tokens.values())
)
assert num_scheduled_tokens == cur_step_action.num_scheduled_tokens
# Use the first mamba KV cache group and the first layer in that group.
mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config)
mamba_group_id = mamba_group_ids[0]
mamba_layer_name = self.kv_cache_config.kv_cache_groups[
mamba_group_id
].layer_names[0]
nonlocal last_num_computed_tokens
if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0:
num_computed_tokens = (
scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
)
# True when the computed-token count crossed a BLOCK_SIZE boundary since
# the previously recorded value.
if (
num_computed_tokens // BLOCK_SIZE
> last_num_computed_tokens // BLOCK_SIZE
):
# generated a new aligned block in this step
block_idx = num_computed_tokens // mamba_spec.block_size - 1
block_id = (
self.input_batch.block_table.block_tables[mamba_group_id]
.block_table.cpu[0, block_idx]
.item()
)
# NOTE(review): block id 0 is presumably the null block - confirm.
if block_id != 0:
kv_cache = self.compilation_config.static_forward_context[
mamba_layer_name
].kv_cache
# Key is num_computed_tokens rounded down to a multiple of BLOCK_SIZE; the
# two tensors are cloned so later steps cannot overwrite the snapshot.
mamba_kv_cache_dict[
num_computed_tokens - num_computed_tokens % BLOCK_SIZE
] = (
kv_cache[0][0][block_id].clone(),
kv_cache[0][1][block_id].clone(),
)
last_num_computed_tokens = num_computed_tokens
else:
last_num_computed_tokens = 0
ret = original_execute_model_fn(self, scheduler_output, intermediate_tensors)
if cur_step_action is not None:
# Checked after the real execute_model has run.
assert (
cur_step_action.num_computed_tokens_start
== self.input_batch.num_computed_tokens_cpu[0].item()
)
return ret
return fake_execute_model_fn
# Build wrappers for the mamba preprocess, postprocess and block-copy functions.
# The copy wrapper records its arguments; the preprocess / postprocess wrappers
# reset that record, call the original, and then compare the record with the
# copy expected by the active StepAction. Returns the three wrappers as a tuple
# (preprocess, postprocess, copy).
def get_fake_process_mamba_fn(
original_preprocess_mamba_fn: Callable,
original_post_process_mamba_fn: Callable,
original_copy_fn: Callable,
):
# (src_state_list, dest_state_list, num_elements_list) from the most recent
# copy call; None until the copy wrapper has been invoked.
copy_info: tuple[list[int], list[int], list[int]] | None = None
# Assert that the recorded copy matches `action`, a (source, destination) pair
# of positions in the mamba block table; (-1, -1) means no copy is expected.
def check_copy_info(
action: tuple[int, int],
kv_cache_config: KVCacheConfig,
forward_context: dict[str, Any],
input_batch: GPUInputBatch,
):
# The copy wrapper must have run, even when nothing was copied.
assert copy_info is not None
if action == (-1, -1):
assert len(copy_info[0]) == len(copy_info[1]) == len(copy_info[2]) == 0
else:
assert len(copy_info[0]) == len(copy_info[1]) == len(copy_info[2]) == 2
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
mamba_group_id = mamba_group_ids[0]
mamba_layer_name = kv_cache_config.kv_cache_groups[
mamba_group_id
].layer_names[0]
mamba_kv_cache = forward_context[mamba_layer_name].kv_cache[0][-1]
# Block table row of the first request in the batch.
mamba_block_table = input_batch.block_table.block_tables[
mamba_group_id
].block_table.cpu[0]
expected_temporal_src = mamba_kv_cache[
mamba_block_table[action[0]]
].data_ptr()
expected_temporal_dest = mamba_kv_cache[
mamba_block_table[action[1]]
].data_ptr()
# -1 is qwen3-next's temporal. We skip checking conv as it is more complex.
assert copy_info[0][-1] == expected_temporal_src
assert copy_info[1][-1] == expected_temporal_dest
# Wrapper for preprocess_mamba: same parameters, forwarded unchanged.
def fake_preprocess_mamba_fn(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
cache_config: CacheConfig,
mamba_state_idx: dict[str, int],
input_batch: GPUInputBatch,
requests: dict[str, CachedRequestState],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
nonlocal copy_info
# Clear the record so the check below only sees this call's copy.
copy_info = None
ret = original_preprocess_mamba_fn(
scheduler_output,
kv_cache_config,
cache_config,
mamba_state_idx,
input_batch,
requests,
forward_context,
mamba_state_copy_funcs,
)
if cur_step_action is not None:
check_copy_info(
cur_step_action.preprocess_copy_idx,
kv_cache_config,
forward_context,
input_batch,
)
return ret
# Wrapper for postprocess_mamba: same parameters, forwarded unchanged.
def fake_post_process_mamba_fn(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
input_batch: GPUInputBatch,
requests: dict[str, CachedRequestState],
mamba_state_idx: dict[str, int],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
nonlocal copy_info
# Clear the record so the check below only sees this call's copy.
copy_info = None
ret = original_post_process_mamba_fn(
scheduler_output,
kv_cache_config,
input_batch,
requests,
mamba_state_idx,
forward_context,
mamba_state_copy_funcs,
)
if cur_step_action is not None:
check_copy_info(
cur_step_action.postprocess_copy_idx,
kv_cache_config,
forward_context,
input_batch,
)
return ret
# Wrapper for do_mamba_copy_block: records the arguments, then delegates.
def fake_copy_fn(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
nonlocal copy_info
# A second copy call without an intervening reset fails here.
assert copy_info is None
copy_info = (src_state_list, dest_state_list, num_elements_list)
return original_copy_fn(
src_state_list,
dest_state_list,
num_elements_list,
)
return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn
def run_ref_mamba_state_in_subprocess() -> None:
    """Run the reference mamba-state worker in a spawned child process.

    The child is started with the "spawn" method and given up to 600 seconds
    to finish.

    Raises:
        RuntimeError: if the child is still running after the timeout (it is
            terminated first so it does not linger), or if it exits with a
            non-zero exit code.
    """
    timeout_s = 600
    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=_run_ref_mamba_state_worker)
    proc.start()
    proc.join(timeout=timeout_s)
    if proc.is_alive():
        # join() returned because of the timeout; exitcode is still None here.
        # Stop the child explicitly instead of leaving it running.
        proc.terminate()
        proc.join()
        raise RuntimeError(
            f"Ref mamba state process did not finish within {timeout_s} seconds."
        )
    if proc.exitcode != 0:
        raise RuntimeError(f"Ref mamba state process exited with code {proc.exitcode}.")
def _run_ref_mamba_state_worker():
    """Generate reference mamba states and save them to disk.

    Patches ``GPUModelRunner.execute_model`` and ``GPUModelRunner._sample``
    with the fakes from this file, runs one generation over the first 500
    prompt tokens, then writes the captured states (moved to CPU) to
    ``mamba_kv_cache_dict_ref.pth``. Any exception is printed and re-raised.
    """
    global prompt_token_ids
    try:
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        num_generated_tokens = 8000
        num_prompt_tokens = 500
        sampling_params = SamplingParams(
            temperature=0.0, max_tokens=num_generated_tokens
        )
        article = datasets.load_dataset("heheda/a_long_article")
        full_prompt = article["train"][0]["text"]

        # Install the fakes directly on the class; this runs in its own process.
        GPUModelRunner.execute_model = get_fake_execute_model_fn(
            GPUModelRunner.execute_model
        )
        GPUModelRunner._sample = get_fake_sample_fn()

        engine = LLM(
            model=MODEL,
            block_size=BLOCK_SIZE,
            hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS},
            seed=42,
        )
        prompt_token_ids = engine.get_tokenizer().encode(full_prompt)
        print(f"Token IDs length: {len(prompt_token_ids)}")
        prompt = TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])
        _outputs = engine.generate([prompt], sampling_params)

        # Move every captured tensor to CPU before serializing.
        cpu_state_ref = {}
        for key, tensors in mamba_kv_cache_dict.items():
            cpu_state_ref[key] = tuple(tensor.detach().cpu() for tensor in tensors)
        torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth")
        mamba_kv_cache_dict.clear()
    except Exception:
        traceback.print_exc()
        raise
def check_mamba_state_equal(
    mamba_state_ref: dict, mamba_state_new: dict, keys_to_check: list[int]
):
    """Compare selected mamba states against a reference.

    For every key in ``keys_to_check`` the tensors of both dicts are compared
    pairwise with ``atol = rtol = 1e-2``. Each new tensor is moved to the
    reference tensor's device and truncated to the reference's first-dimension
    length before comparison.

    Returns:
        ``True`` when no comparison failed.

    Raises:
        AssertionError: if a key is missing from either dict.
        ValueError: if at least 1% of a tensor's elements differ. Smaller
            mismatch ratios only print a warning.
    """
    tolerance = {"atol": 1e-2, "rtol": 1e-2}
    for key in keys_to_check:
        assert key in mamba_state_new
        assert key in mamba_state_ref
        # The new state may hold more rows than the reference.
        tensor_pairs = zip(mamba_state_ref[key], mamba_state_new[key])
        for i, (ref, new) in enumerate(tensor_pairs):
            candidate = new if ref.device == new.device else new.to(ref.device)
            candidate = candidate[: ref.shape[0]]
            if torch.allclose(ref, candidate, **tolerance):
                continue
            mismatch_idx = torch.nonzero(~torch.isclose(ref, candidate, **tolerance))
            num_mismatches = mismatch_idx.shape[0]
            if num_mismatches * 100 >= ref.numel():
                raise ValueError(
                    f"Mamba state is not equal for key: {key} at index {i}"
                )
            print(
                f"[WARNING] found {num_mismatches * 100 / ref.numel()}% of the elements are different"  # noqa: E501
            )
    return True
@dataclass
class TestConfig:
    """Parameters and per-step expectations of one prefix-cache test case."""

    # The name starts with "Test", so pytest would try to collect this class
    # and warn that it has an __init__ constructor; opt out explicitly. The
    # attribute is unannotated, so it is not a dataclass field.
    __test__ = False

    # Number of leading prompt_token_ids sent as the prompt.
    num_prompt_tokens: int
    # Used as SamplingParams(max_tokens=...).
    num_generated_tokens: int
    # Copied into the module-level num_accepted_tokens before the case runs.
    num_accepted_tokens: int
    # Expected behaviour of each engine step, in order.
    step_actions: list[StepAction]
def apply_patch(monkeypatch: pytest.MonkeyPatch):
    """Install every fake hook of this file through ``monkeypatch``.

    Disables V1 multiprocessing via the environment, then replaces the
    sampler, draft proposer and ``execute_model`` on ``GPUModelRunner``,
    ``get_output`` on ``InprocClient``, ``allocate_slots`` on
    ``KVCacheManager``, and the three mamba helpers on ``mamba_utils``.
    Each wrapper captures the original attribute before it is replaced.
    """
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    monkeypatch.setattr(GPUModelRunner, "_sample", get_fake_sample_fn())
    monkeypatch.setattr(
        GPUModelRunner,
        "propose_draft_token_ids",
        get_fake_propose_draft_token_ids_fn(),
    )
    monkeypatch.setattr(
        GPUModelRunner,
        "execute_model",
        get_fake_execute_model_fn(GPUModelRunner.execute_model),
    )
    monkeypatch.setattr(
        InprocClient, "get_output", get_fake_step_action_fn(InprocClient.get_output)
    )
    monkeypatch.setattr(
        KVCacheManager,
        "allocate_slots",
        get_fake_allocate_slots_fn(KVCacheManager.allocate_slots),
    )

    mamba_hooks = get_fake_process_mamba_fn(
        mamba_utils.preprocess_mamba,
        mamba_utils.postprocess_mamba,
        mamba_utils.do_mamba_copy_block,
    )
    mamba_attr_names = ("preprocess_mamba", "postprocess_mamba", "do_mamba_copy_block")
    for attr_name, hook in zip(mamba_attr_names, mamba_hooks):
        monkeypatch.setattr(mamba_utils, attr_name, hook)
# End-to-end check of mamba prefix caching in "align" mode with speculative
# decoding. A reference run in a subprocess stores mamba states on disk; the
# patched engine then runs each test case, the fake hooks assert the per-step
# expectations, and the states captured at copied blocks are compared with the
# reference. Currently skipped (see the reason below).
@pytest.mark.skip(
reason="Skipping test_mamba_prefix_cache because it is based on spec "
"decode which is not allowed now."
)
def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
run_ref_mamba_state_in_subprocess()
apply_patch(monkeypatch)
prompt_dataset = datasets.load_dataset("heheda/a_long_article")
full_prompt = prompt_dataset["train"][0]["text"]
# StepAction positional fields, in order: num_computed_tokens_start,
# num_scheduled_tokens, kv_cache_block_ids ([] = same as previous step),
# preprocess_copy_idx, postprocess_copy_idx ((-1, -1) = no copy expected).
tests = {
"accept_1": TestConfig(
num_prompt_tokens=554,
num_generated_tokens=20,
num_accepted_tokens=1,
step_actions=[
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(554, 4, [], (-1, -1), (-1, -1)),
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(556, 4, [], (-1, -1), (-1, -1)),
StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
StepAction(558, 4, [], (-1, -1), (-1, -1)),
StepAction(559, 4, [], (-1, -1), (1, 0)),
StepAction(560, 4, [], (-1, -1), (-1, -1)),
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
# test case 2.1: no hit, accept 2 tokens
"accept_2_1": TestConfig(
num_prompt_tokens=554,
num_generated_tokens=20,
num_accepted_tokens=2,
step_actions=[
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(554, 4, [], (-1, -1), (-1, -1)),
StepAction(556, 4, [], (-1, -1), (-1, -1)),
StepAction(558, 4, [1, 1, 1, 1, 1], (1, 1), (2, 0)),
StepAction(560, 4, [], (-1, -1), (-1, -1)),
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
# test case 2.2: no hit, accept 2 tokens
"accept_2_2": TestConfig(
num_prompt_tokens=555,
num_generated_tokens=20,
num_accepted_tokens=2,
step_actions=[
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(557, 4, [1, 1, 1, 1, 1], (1, 1), (-1, -1)),
StepAction(559, 4, [], (-1, -1), (1, 0)),
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_3_1": TestConfig(
num_prompt_tokens=553,
num_generated_tokens=20,
num_accepted_tokens=3,
step_actions=[
StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(553, 4, [], (-1, -1), (-1, -1)),
StepAction(556, 4, [], (-1, -1), (-1, -1)),
StepAction(559, 4, [1, 1, 1, 1, 1], (2, 1), (1, 0)),
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_3_2": TestConfig(
num_prompt_tokens=554,
num_generated_tokens=20,
num_accepted_tokens=3,
step_actions=[
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(554, 4, [], (-1, -1), (-1, -1)),
StepAction(557, 4, [1, 1, 1, 1, 1], (2, 1), (3, 0)),
StepAction(560, 4, [], (-1, -1), (-1, -1)),
StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_3_3": TestConfig(
num_prompt_tokens=555,
num_generated_tokens=20,
num_accepted_tokens=3,
step_actions=[
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(558, 4, [1, 1, 1, 1, 1], (2, 1), (2, 0)),
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_4_1": TestConfig(
num_prompt_tokens=553,
num_generated_tokens=20,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(553, 4, [], (-1, -1), (-1, -1)),
StepAction(557, 4, [1, 1, 1, 1, 1], (3, 1), (3, 0)),
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(565, 4, [], (-1, -1), (-1, -1)),
],
),
"accept_4_2": TestConfig(
num_prompt_tokens=554,
num_generated_tokens=25,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(554, 4, [], (-1, -1), (-1, -1)),
StepAction(558, 4, [1, 1, 1, 1, 1], (3, 1), (2, 0)),
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(566, 4, [], (-1, -1), (-1, -1)),
],
),
"accept_4_3": TestConfig(
num_prompt_tokens=555,
num_generated_tokens=25,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(555, 4, [], (-1, -1), (-1, -1)),
StepAction(559, 4, [1, 1, 1, 1, 1], (3, 1), (1, 0)),
StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"accept_4_4": TestConfig(
num_prompt_tokens=556,
num_generated_tokens=25,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 556, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(556, 4, [], (-1, -1), (3, 0)),
StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
# The remaining cases use prompts of one or more whole blocks (560 tokens
# each), with or without 10 extra tokens.
"prompt_block_size": TestConfig(
num_prompt_tokens=560,
num_generated_tokens=10,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
],
),
"prompt_2_block_size": TestConfig(
num_prompt_tokens=560 * 2,
num_generated_tokens=10,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(560, 560, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
StepAction(560 * 2, 4, [0, 1, 1, 1, 1, 1], (1, 2), (-1, -1)),
],
),
"prompt_2_block_size_10": TestConfig(
num_prompt_tokens=560 * 2 + 10,
num_generated_tokens=10,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(560, 570, [1, 0, 1, 1, 1, 1], (0, 2), (-1, -1)),
StepAction(560 * 2 + 10, 4, [0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"prompt_3_block_size": TestConfig(
num_prompt_tokens=560 * 3,
num_generated_tokens=10,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(560 * 2, 560, [0, 1, 1, 1, 1, 1], (1, 2), (-1, -1)),
StepAction(560 * 3, 4, [0, 0, 1, 1, 1, 1, 1], (2, 3), (-1, -1)),
],
),
"prompt_3_block_size_10": TestConfig(
num_prompt_tokens=560 * 3 + 10,
num_generated_tokens=10,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(560 * 2, 570, [0, 1, 0, 1, 1, 1, 1], (1, 3), (-1, -1)),
StepAction(560 * 3 + 10, 4, [0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
],
),
"prompt_10_block_size": TestConfig(
num_prompt_tokens=560 * 10,
num_generated_tokens=10,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(
560 * 5,
560 * 4,
[0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1],
(4, 8),
(-1, -1),
),
StepAction(
560 * 9,
560,
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
(8, 9),
(-1, -1),
),
StepAction(
560 * 10,
4,
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
(9, 10),
(-1, -1),
),
],
),
"prompt_10_block_size_10": TestConfig(
num_prompt_tokens=560 * 10 + 10,
num_generated_tokens=10,
num_accepted_tokens=4,
step_actions=[
StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
StepAction(
560 * 5,
560 * 4,
[0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1],
(4, 8),
(-1, -1),
),
StepAction(
560 * 9,
560 + 10,
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1],
(8, 10),
(-1, -1),
),
],
),
}
# One engine instance is shared by all test cases.
engine = LLM(
model=MODEL,
enable_prefix_caching=True,
block_size=BLOCK_SIZE,
mamba_cache_mode="align",
speculative_config={
"method": "qwen3_next_mtp",
"num_speculative_tokens": num_speculative_tokens,
},
max_num_batched_tokens=3072,
hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS},
seed=42,
)
global prompt_token_ids
prompt_token_ids = engine.get_tokenizer().encode(full_prompt)
print(f"Token IDs length: {len(prompt_token_ids)}")
for test_case_name, test_config in tests.items():
print(f"Running test case: {test_case_name}")
num_generated_tokens = test_config.num_generated_tokens
num_prompt_tokens = test_config.num_prompt_tokens
# The fake sampler reads this module-level value.
global num_accepted_tokens
num_accepted_tokens = test_config.num_accepted_tokens
sampling_params = SamplingParams(
temperature=0.0, max_tokens=num_generated_tokens
)
# Restart the step cursor used by the fake get_output hook.
global cur_step_action_idx
cur_step_action_idx = 0
# Replace each empty kv_cache_block_ids ("follow last step") with a copy of
# the previous step's list; walking pairs in order lets the value propagate
# through consecutive empty entries.
for step_action_prev, step_action_next in zip(
test_config.step_actions[:-1], test_config.step_actions[1:]
):
if (
step_action_next.kv_cache_block_ids is not None
and len(step_action_next.kv_cache_block_ids) == 0
):
prev_block_ids = step_action_prev.kv_cache_block_ids
if prev_block_ids is not None:
step_action_next.kv_cache_block_ids = prev_block_ids.copy()
global step_actions
step_actions = test_config.step_actions
_ = engine.generate(
[TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])],
sampling_params,
)
# Reset the prefix cache after the generation; the call must report success.
assert engine.llm_engine.engine_core.engine_core.scheduler.reset_prefix_cache()
print(f"End test case: {test_case_name}")
# One key per step with an expected postprocess copy: (destination index + 1)
# * BLOCK_SIZE.
keys_to_check = [
(action.postprocess_copy_idx[1] + 1) * BLOCK_SIZE
for action in test_config.step_actions
if action.postprocess_copy_idx and action.postprocess_copy_idx[0] != -1
]
# NOTE(review): the reference file is reloaded on every use; it looks
# loop-invariant and could be loaded once - confirm.
mamba_state_ref = torch.load("mamba_kv_cache_dict_ref.pth")
check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check)
mamba_kv_cache_dict.clear()
......@@ -31,6 +31,7 @@ CacheDType = Literal[
"fp8_ds_mla",
]
MambaDType = Literal["auto", "float32", "float16"]
MambaCacheMode = Literal["all", "align", "none"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
......@@ -123,6 +124,15 @@ class CacheConfig:
"""The data type to use for the Mamba cache (ssm state only, conv state will
still be controlled by mamba_cache_dtype). If set to 'auto', the data type
for the ssm state will be determined by mamba_cache_dtype."""
mamba_cache_mode: MambaCacheMode = "none"
"""The cache strategy for Mamba layers.
- "none": set when prefix caching is disabled.
- "all": cache the mamba state of all tokens at position i * block_size. This is
the default behavior (for models that support it) when prefix caching is
enabled.
- "align": only cache the mamba state of the last token of each scheduler step and
when the token is at position i * block_size.
"""
# Will be set after profiling.
num_gpu_blocks: int | None = field(default=None, init=False)
......
......@@ -999,6 +999,17 @@ class VllmConfig:
# Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False
if self.cache_config.mamba_cache_mode == "align":
if self.scheduler_config.long_prefill_token_threshold > 0:
assert (
self.scheduler_config.long_prefill_token_threshold
>= self.cache_config.block_size
)
assert not self.scheduler_config.disable_chunked_mm_input, (
"Chunked MM input is required because we need the flexibility to "
"schedule a multiple of block_size tokens even if they are in the "
"middle of a mm input"
)
if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = (
self.compilation_config.debug_dump_path.absolute().expanduser()
......
......@@ -60,6 +60,7 @@ from vllm.config.cache import (
BlockSize,
CacheDType,
KVOffloadingBackend,
MambaCacheMode,
MambaDType,
PrefixCachingHashAlgo,
)
......@@ -556,6 +557,7 @@ class EngineArgs:
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
......@@ -939,6 +941,9 @@ class EngineArgs:
cache_group.add_argument(
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
)
cache_group.add_argument(
"--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
)
cache_group.add_argument(
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
)
......@@ -1416,6 +1421,7 @@ class EngineArgs:
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size,
mamba_cache_mode=self.mamba_cache_mode,
kv_offloading_size=self.kv_offloading_size,
kv_offloading_backend=self.kv_offloading_backend,
)
......
......@@ -56,6 +56,7 @@ class MambaBase(AttentionLayerBase):
block_size=mamba_block_size,
page_size_padded=page_size_padded,
mamba_type=self.mamba_type,
mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode,
num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config
......
......@@ -255,7 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
......@@ -304,7 +304,7 @@ class MambaMixer(MambaBase, CustomOp):
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
if prefix_caching_enabled:
if is_mamba_cache_all:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
torch.split(
attn_metadata.block_idx_last_computed_token,
......@@ -380,7 +380,7 @@ class MambaMixer(MambaBase, CustomOp):
ssm_outputs.append(scan_out_p)
if has_decode:
if prefix_caching_enabled:
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)
......
......@@ -570,7 +570,7 @@ class MambaMixer2(MambaBase, CustomOp):
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
......@@ -622,7 +622,7 @@ class MambaMixer2(MambaBase, CustomOp):
dim=0,
)
if prefix_caching_enabled:
if is_mamba_cache_all:
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
......@@ -701,7 +701,7 @@ class MambaMixer2(MambaBase, CustomOp):
initial_states = None
if has_initial_states_p is not None and prep_initial_states:
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
if is_mamba_cache_all:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, block_idx_last_computed_token_p.unsqueeze(1)
).squeeze(1)
......@@ -729,14 +729,14 @@ class MambaMixer2(MambaBase, CustomOp):
cu_chunk_seqlens=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
initial_states=initial_states,
return_intermediate_states=prefix_caching_enabled,
return_intermediate_states=is_mamba_cache_all,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
state_dtype=ssm_state.dtype,
)
if prefix_caching_enabled:
if is_mamba_cache_all:
# The chunk_stride is the number of chunks per mamba block
# e.g., if mamba_block_size = 512 and chunk_size = 256,
# then chunk_stride = 2
......@@ -815,7 +815,7 @@ class MambaMixer2(MambaBase, CustomOp):
# Process decode requests
if has_decode:
if prefix_caching_enabled:
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeAlias
import torch
from vllm.config.cache import MambaDType
......@@ -223,3 +227,94 @@ class MambaStateShapeCalculator:
conv_state_k_shape,
recurrent_state_shape,
)
@dataclass
class MambaCopySpec:
    """Describe one memory copy of a Mamba state slice.

    Used for prefix caching in align mode.

    Attributes:
        start_addr (int): Address at which the copy starts.
        num_elements (int): How many elements to copy, counted from
            ``start_addr``.
    """

    start_addr: int
    num_elements: int
# Shared call signature of the copy-spec helpers in this module:
# (state, block_ids, cur_block_idx, num_accepted_tokens) -> MambaCopySpec.
MambaStateCopyFunc: TypeAlias = Callable[
[torch.Tensor, list[int], int, int], MambaCopySpec
]
"""
Type alias for a function that computes a MambaCopySpec for copying state slices.
Parameters:
state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states).
block_ids: list[int] - the list of block indices for the state to copy.
cur_block_idx: int - current block index within `block_ids` to copy from.
num_accepted_tokens: int - number of accepted tokens used to compute the copy offset.
Range: 1 .. 1 + num_speculative_tokens (inclusive).
"""
def get_conv_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Build the copy spec for a slice of a convolutional state.

    Args:
        state: Mamba state tensor, indexed by block id on its first dimension.
        block_ids: Block ids belonging to the state being copied.
        cur_block_idx: Position inside ``block_ids`` of the source block.
        num_accepted_tokens: Number of accepted tokens; the first
            ``num_accepted_tokens - 1`` entries along the second dimension
            of the source block are left out of the copy.

    Returns:
        A MambaCopySpec holding the data pointer and element count of the
        selected slice.
    """
    # With a single accepted token the offset is 0 and the whole block
    # is selected.
    skip = num_accepted_tokens - 1
    block_view = state[block_ids[cur_block_idx], skip:]
    return MambaCopySpec(
        start_addr=block_view.data_ptr(),
        num_elements=block_view.numel(),
    )
def get_temporal_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Build the copy spec for a whole temporal-state block.

    Args:
        state: Mamba state tensor, indexed by block id on its first dimension.
        block_ids: Block ids belonging to the state being copied.
        cur_block_idx: Base position inside ``block_ids``.
        num_accepted_tokens: Number of accepted tokens; the source block is
            the one at position ``cur_block_idx + num_accepted_tokens - 1``.

    Returns:
        A MambaCopySpec holding the data pointer and element count of the
        selected block.
    """
    # NOTE(review): presumably each accepted token's state sits in its own
    # consecutive entry of block_ids -- confirm against the caller.
    position = cur_block_idx + num_accepted_tokens - 1
    block_view = state[block_ids[position]]
    return MambaCopySpec(
        start_addr=block_view.data_ptr(),
        num_elements=block_view.numel(),
    )


# Same callable exposed under a second name.
get_full_copy_spec = get_temporal_copy_spec
class MambaStateCopyFuncCalculator:
    """Lookup of copy-spec function tuples, one classmethod per layer type.

    Each method returns a tuple built from ``get_conv_copy_spec`` and
    ``get_temporal_copy_spec``.
    """

    @classmethod
    def linear_attention_state_copy_func(cls):
        """Return a 1-tuple holding the temporal copy function."""
        return (get_temporal_copy_spec,)

    @classmethod
    def mamba1_state_copy_func(cls):
        """Return the (conv, temporal) copy functions."""
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def mamba2_state_copy_func(cls):
        """Return the (conv, temporal) copy functions."""
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def short_conv_state_copy_func(cls):
        """Return a 1-tuple holding the conv copy function."""
        return (get_conv_copy_spec,)

    @classmethod
    def gated_delta_net_state_copy_func(cls):
        """Return the (conv, temporal) copy functions."""
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def kda_state_copy_func(cls):
        """Return three conv copy functions followed by one temporal one."""
        conv_funcs = (get_conv_copy_spec,) * 3
        return conv_funcs + (get_temporal_copy_spec,)
......@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -455,6 +457,10 @@ class BambaForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return this model's Mamba state copy-spec functions.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func``.
    """
    return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
......
......@@ -330,26 +330,54 @@ class MambaModelConfig(VerifyAndUpdateConfig):
cache_config = vllm_config.cache_config
if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching:
logger.info(
"Warning: Prefix caching is currently enabled. "
"Its support for Mamba layers is experimental. "
"Please report any issues you may observe."
if cache_config.mamba_cache_mode == "none":
cache_config.mamba_cache_mode = (
"all" if model_config.supports_mamba_prefix_caching else "align"
)
# By default, mamba block size will be set to max_model_len (see
# below). When enabling prefix caching, we align mamba block size
# to the block size as the basic granularity for prefix caching.
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = cache_config.block_size
else:
logger.info(
"Hybrid or mamba-based model detected without "
"support for prefix caching: disabling."
logger.warning(
"Mamba cache mode is set to '%s' for %s by default "
"when prefix caching is enabled",
cache_config.mamba_cache_mode,
model_config.architecture,
)
cache_config.enable_prefix_caching = False
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
if (
cache_config.mamba_cache_mode == "all"
and not model_config.supports_mamba_prefix_caching
):
cache_config.mamba_cache_mode = "align"
logger.warning(
"Hybrid or mamba-based model detected without support "
"for prefix caching with Mamba cache 'all' mode: "
"falling back to 'align' mode."
)
if cache_config.mamba_cache_mode == "align":
assert vllm_config.scheduler_config.enable_chunked_prefill, (
"Chunked prefill is required for mamba cache mode 'align'."
)
assert not vllm_config.speculative_config, (
"Mamba cache mode 'align' is currently not compatible "
"with speculative decoding."
)
logger.info(
"Warning: Prefix caching in Mamba cache '%s' "
"mode is currently enabled. "
"Its support for Mamba layers is experimental. "
"Please report any issues you may observe.",
cache_config.mamba_cache_mode,
)
# By default, mamba block size will be set to max_model_len (see
# below). When enabling prefix caching, we align mamba block size
# to the block size as the basic granularity for prefix caching.
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = cache_config.block_size
else:
if cache_config.mamba_cache_mode != "none":
cache_config.mamba_cache_mode = "none"
logger.warning(
"Mamba cache mode is set to 'none' when prefix caching is disabled"
)
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
......@@ -426,7 +454,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=model_config.max_model_len,
block_size=-1, # block_size doesn't matter for mamba page size
).page_size_bytes
# Model may be marked as is_hybrid
......@@ -435,7 +463,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
if mamba_page_size == 0:
return
if cache_config.enable_prefix_caching:
if cache_config.mamba_cache_mode == "all":
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
......@@ -479,6 +507,13 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
attn_block_size,
)
# By default, mamba block size will be set to max_model_len.
# When enabling prefix caching and using align mamba cache
# mode, we align mamba block size to the block size as the
# basic granularity for prefix caching.
if cache_config.mamba_cache_mode == "align":
cache_config.mamba_block_size = cache_config.block_size
# compute new attention page size
attn_page_size = cache_config.block_size * attn_page_size_1_token
......
......@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -551,6 +553,10 @@ class FalconH1ForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return this model's Mamba state copy-spec functions.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func``.
    """
    return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
......
......@@ -19,6 +19,8 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLine
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -641,6 +643,10 @@ class GraniteMoeHybridForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return this model's Mamba state copy-spec functions.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func``.
    """
    return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -24,6 +24,7 @@ from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.collection_utils import common_prefix
from vllm.utils.func_utils import supports_kw
......@@ -776,6 +777,19 @@ class IsHybrid(Protocol):
"""
...
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
    """Provide one copy function per Mamba state of the model.

    Returns:
        A tuple of MambaStateCopyFunc callables, ordered like the Mamba
        states the model produces. Each callable is invoked as
        ``(state, block_ids, cur_block_idx, num_accepted_tokens)`` and
        yields a MambaCopySpec with the memory-copy parameters used for
        prefix caching in align mode.
    """
    ...
@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]: ...
......
......@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -558,6 +560,10 @@ class JambaForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return this model's Mamba state copy-spec functions.

    Delegates to ``MambaStateCopyFuncCalculator.mamba1_state_copy_func``.
    """
    return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -26,6 +26,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -544,6 +546,14 @@ class KimiLinearForCausalLM(
num_spec=num_spec,
)
@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[
    MambaStateCopyFunc,
    MambaStateCopyFunc,
    MambaStateCopyFunc,
    MambaStateCopyFunc,
]:
    """Return this model's four Mamba state copy-spec functions.

    Delegates to ``MambaStateCopyFuncCalculator.kda_state_copy_func``.
    """
    return MambaStateCopyFuncCalculator.kda_state_copy_func()
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -459,14 +461,19 @@ class Lfm2ForCausalLM(
conv_kernel=hf_config.conv_L_cache,
)
@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[MambaStateCopyFunc]:
    """Return this model's single state copy-spec function.

    Delegates to ``MambaStateCopyFuncCalculator.short_conv_state_copy_func``.
    """
    return MambaStateCopyFuncCalculator.short_conv_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, (
"Lfm2 currently does not support prefix caching"
)
if cache_config.mamba_cache_mode == "all":
raise NotImplementedError(
"Lfm2 currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
super().__init__()
self.config = config
......
......@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -640,6 +642,10 @@ class Lfm2MoeForCausalLM(
conv_kernel=hf_config.conv_L_cache,
)
@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[MambaStateCopyFunc]:
    """Return this model's single state copy-spec function.

    Delegates to ``MambaStateCopyFuncCalculator.short_conv_state_copy_func``.
    """
    return MambaStateCopyFuncCalculator.short_conv_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......
......@@ -16,6 +16,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -261,6 +263,10 @@ class MambaForCausalLM(
conv_kernel=hf_config.conv_kernel,
)
@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return this model's Mamba state copy-spec functions.

    Delegates to ``MambaStateCopyFuncCalculator.mamba1_state_copy_func``.
    """
    return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
    """Forward the call to ``self.mamba_cache``.

    ``input_buffers`` and all keyword arguments are passed through
    unchanged, and the cache's return value is returned as-is.
    """
    cache = self.mamba_cache
    return cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
......
......@@ -15,6 +15,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -228,6 +230,10 @@ class Mamba2ForCausalLM(
conv_kernel=hf_config.conv_kernel,
)
@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return this model's Mamba state copy-spec functions.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func``.
    """
    return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment