Unverified commit 5206e5e2, authored by Harry Huang, committed by GitHub
Browse files

[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)


Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
parent fec9da0a
...@@ -24,7 +24,7 @@ pytestmark = pytest.mark.cpu_test ...@@ -24,7 +24,7 @@ pytestmark = pytest.mark.cpu_test
def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True): def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True):
return SlidingWindowManager( return SlidingWindowManager(
sliding_window_spec, sliding_window_spec,
block_pool, block_pool=block_pool,
enable_caching=enable_caching, enable_caching=enable_caching,
kv_cache_group_id=0, kv_cache_group_id=0,
) )
...@@ -35,7 +35,7 @@ def get_chunked_local_attention_manager( ...@@ -35,7 +35,7 @@ def get_chunked_local_attention_manager(
): ):
return ChunkedLocalAttentionManager( return ChunkedLocalAttentionManager(
chunked_local_attention_spec, chunked_local_attention_spec,
block_pool, block_pool=block_pool,
enable_caching=enable_caching, enable_caching=enable_caching,
kv_cache_group_id=0, kv_cache_group_id=0,
) )
...@@ -342,11 +342,15 @@ def test_get_num_blocks_to_allocate(): ...@@ -342,11 +342,15 @@ def test_get_num_blocks_to_allocate():
] ]
assert ( assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0) manager.get_num_blocks_to_allocate(
"1", 20 * block_size, cached_blocks_1, 0, 20 * block_size
)
== 20 == 20
) )
assert ( assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0) manager.get_num_blocks_to_allocate(
"2", 20 * block_size, cached_blocks_2, 0, 20 * block_size
)
== 15 == 15
) )
...@@ -375,6 +379,7 @@ def test_evictable_cached_blocks_not_double_allocated(): ...@@ -375,6 +379,7 @@ def test_evictable_cached_blocks_not_double_allocated():
num_tokens=2 * block_size, num_tokens=2 * block_size,
new_computed_blocks=[evictable_block], new_computed_blocks=[evictable_block],
total_computed_tokens=block_size, total_computed_tokens=block_size,
num_tokens_main_model=2 * block_size,
) )
# Free capacity check should count evictable cached blocks, but allocation # Free capacity check should count evictable cached blocks, but allocation
# should only allocate the truly new block. # should only allocate the truly new block.
...@@ -386,7 +391,9 @@ def test_evictable_cached_blocks_not_double_allocated(): ...@@ -386,7 +391,9 @@ def test_evictable_cached_blocks_not_double_allocated():
num_local_computed_tokens=block_size, num_local_computed_tokens=block_size,
num_external_computed_tokens=0, num_external_computed_tokens=0,
) )
new_blocks = manager.allocate_new_blocks(request_id, num_tokens=4) new_blocks = manager.allocate_new_blocks(
request_id, num_tokens=4, num_tokens_main_model=4
)
assert len(new_blocks) == 1 assert len(new_blocks) == 1
assert len(manager.req_to_blocks[request_id]) == 2 assert len(manager.req_to_blocks[request_id]) == 2
...@@ -411,10 +418,14 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): ...@@ -411,10 +418,14 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
] ]
assert ( assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0) manager.get_num_blocks_to_allocate(
"1", 20 * block_size, cached_blocks_1, 0, 20 * block_size
)
== 20 == 20
) )
assert ( assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0) manager.get_num_blocks_to_allocate(
"2", 20 * block_size, cached_blocks_2, 0, 20 * block_size
)
== 15 == 15
) )
This diff is collapsed.
...@@ -31,6 +31,7 @@ CacheDType = Literal[ ...@@ -31,6 +31,7 @@ CacheDType = Literal[
"fp8_ds_mla", "fp8_ds_mla",
] ]
MambaDType = Literal["auto", "float32", "float16"] MambaDType = Literal["auto", "float32", "float16"]
MambaCacheMode = Literal["all", "align", "none"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"] KVOffloadingBackend = Literal["native", "lmcache"]
...@@ -123,6 +124,15 @@ class CacheConfig: ...@@ -123,6 +124,15 @@ class CacheConfig:
"""The data type to use for the Mamba cache (ssm state only, conv state will """The data type to use for the Mamba cache (ssm state only, conv state will
still be controlled by mamba_cache_dtype). If set to 'auto', the data type still be controlled by mamba_cache_dtype). If set to 'auto', the data type
for the ssm state will be determined by mamba_cache_dtype.""" for the ssm state will be determined by mamba_cache_dtype."""
mamba_cache_mode: MambaCacheMode = "none"
"""The cache strategy for Mamba layers.
- "none": set when prefix caching is disabled.
- "all": cache the mamba state of all tokens at position i * block_size. This is
the default behavior (for models that support it) when prefix caching is
enabled.
- "align": only cache the mamba state of the last token of each scheduler step and
when the token is at position i * block_size.
"""
# Will be set after profiling. # Will be set after profiling.
num_gpu_blocks: int | None = field(default=None, init=False) num_gpu_blocks: int | None = field(default=None, init=False)
......
...@@ -999,6 +999,17 @@ class VllmConfig: ...@@ -999,6 +999,17 @@ class VllmConfig:
# Default to enable HMA if not explicitly disabled by user or logic above. # Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False self.scheduler_config.disable_hybrid_kv_cache_manager = False
if self.cache_config.mamba_cache_mode == "align":
if self.scheduler_config.long_prefill_token_threshold > 0:
assert (
self.scheduler_config.long_prefill_token_threshold
>= self.cache_config.block_size
)
assert not self.scheduler_config.disable_chunked_mm_input, (
"Chunked MM input is required because we need the flexibility to "
"schedule a multiple of block_size tokens even if they are in the "
"middle of a mm input"
)
if self.compilation_config.debug_dump_path: if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path = (
self.compilation_config.debug_dump_path.absolute().expanduser() self.compilation_config.debug_dump_path.absolute().expanduser()
......
...@@ -60,6 +60,7 @@ from vllm.config.cache import ( ...@@ -60,6 +60,7 @@ from vllm.config.cache import (
BlockSize, BlockSize,
CacheDType, CacheDType,
KVOffloadingBackend, KVOffloadingBackend,
MambaCacheMode,
MambaDType, MambaDType,
PrefixCachingHashAlgo, PrefixCachingHashAlgo,
) )
...@@ -556,6 +557,7 @@ class EngineArgs: ...@@ -556,6 +557,7 @@ class EngineArgs:
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size") mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
...@@ -939,6 +941,9 @@ class EngineArgs: ...@@ -939,6 +941,9 @@ class EngineArgs:
cache_group.add_argument( cache_group.add_argument(
"--mamba-block-size", **cache_kwargs["mamba_block_size"] "--mamba-block-size", **cache_kwargs["mamba_block_size"]
) )
cache_group.add_argument(
"--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
)
cache_group.add_argument( cache_group.add_argument(
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"] "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
) )
...@@ -1416,6 +1421,7 @@ class EngineArgs: ...@@ -1416,6 +1421,7 @@ class EngineArgs:
mamba_cache_dtype=self.mamba_cache_dtype, mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size, mamba_block_size=self.mamba_block_size,
mamba_cache_mode=self.mamba_cache_mode,
kv_offloading_size=self.kv_offloading_size, kv_offloading_size=self.kv_offloading_size,
kv_offloading_backend=self.kv_offloading_backend, kv_offloading_backend=self.kv_offloading_backend,
) )
......
...@@ -56,6 +56,7 @@ class MambaBase(AttentionLayerBase): ...@@ -56,6 +56,7 @@ class MambaBase(AttentionLayerBase):
block_size=mamba_block_size, block_size=mamba_block_size,
page_size_padded=page_size_padded, page_size_padded=page_size_padded,
mamba_type=self.mamba_type, mamba_type=self.mamba_type,
mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode,
num_speculative_blocks=( num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config if vllm_config.speculative_config
......
...@@ -255,7 +255,7 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -255,7 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
assert self.cache_config is not None assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None: if attn_metadata is not None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
...@@ -304,7 +304,7 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -304,7 +304,7 @@ class MambaMixer(MambaBase, CustomOp):
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
if prefix_caching_enabled: if is_mamba_cache_all:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
torch.split( torch.split(
attn_metadata.block_idx_last_computed_token, attn_metadata.block_idx_last_computed_token,
...@@ -380,7 +380,7 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -380,7 +380,7 @@ class MambaMixer(MambaBase, CustomOp):
ssm_outputs.append(scan_out_p) ssm_outputs.append(scan_out_p)
if has_decode: if has_decode:
if prefix_caching_enabled: if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather( state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1) 1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1) ).squeeze(1)
......
...@@ -570,7 +570,7 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -570,7 +570,7 @@ class MambaMixer2(MambaBase, CustomOp):
assert self.cache_config is not None assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None: if attn_metadata is not None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
...@@ -622,7 +622,7 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -622,7 +622,7 @@ class MambaMixer2(MambaBase, CustomOp):
dim=0, dim=0,
) )
if prefix_caching_enabled: if is_mamba_cache_all:
# If prefix caching is enabled, retrieve the relevant variables # If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode # for prefill and decode
block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
...@@ -701,7 +701,7 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -701,7 +701,7 @@ class MambaMixer2(MambaBase, CustomOp):
initial_states = None initial_states = None
if has_initial_states_p is not None and prep_initial_states: if has_initial_states_p is not None and prep_initial_states:
kernel_ssm_indices = state_indices_tensor_p kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled: if is_mamba_cache_all:
kernel_ssm_indices = state_indices_tensor_p.gather( kernel_ssm_indices = state_indices_tensor_p.gather(
1, block_idx_last_computed_token_p.unsqueeze(1) 1, block_idx_last_computed_token_p.unsqueeze(1)
).squeeze(1) ).squeeze(1)
...@@ -729,14 +729,14 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -729,14 +729,14 @@ class MambaMixer2(MambaBase, CustomOp):
cu_chunk_seqlens=cu_chunk_seqlen_p, cu_chunk_seqlens=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p, last_chunk_indices=last_chunk_indices_p,
initial_states=initial_states, initial_states=initial_states,
return_intermediate_states=prefix_caching_enabled, return_intermediate_states=is_mamba_cache_all,
dt_softplus=True, dt_softplus=True,
dt_limit=(0.0, float("inf")), dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
state_dtype=ssm_state.dtype, state_dtype=ssm_state.dtype,
) )
if prefix_caching_enabled: if is_mamba_cache_all:
# The chunk_stride is the number of chunks per mamba block # The chunk_stride is the number of chunks per mamba block
# e.g., if mamba_block_size = 512 and chunk_size = 256, # e.g., if mamba_block_size = 512 and chunk_size = 256,
# then chunk_stride = 2 # then chunk_stride = 2
...@@ -815,7 +815,7 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -815,7 +815,7 @@ class MambaMixer2(MambaBase, CustomOp):
# Process decode requests # Process decode requests
if has_decode: if has_decode:
if prefix_caching_enabled: if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather( state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1) 1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1) ).squeeze(1)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeAlias
import torch import torch
from vllm.config.cache import MambaDType from vllm.config.cache import MambaDType
...@@ -223,3 +227,94 @@ class MambaStateShapeCalculator: ...@@ -223,3 +227,94 @@ class MambaStateShapeCalculator:
conv_state_k_shape, conv_state_k_shape,
recurrent_state_shape, recurrent_state_shape,
) )
@dataclass
class MambaCopySpec:
"""
Data class specifying the memory-copy parameters for Mamba states used for
prefix caching in align mode.
Attributes:
start_addr (int): Starting address for the memory copy operation.
num_elements (int): Number of elements to copy from the starting address.
"""
start_addr: int
num_elements: int
MambaStateCopyFunc: TypeAlias = Callable[
[torch.Tensor, list[int], int, int], MambaCopySpec
]
"""
Type alias for a function that computes a MambaCopySpec for copying state slices.
Parameters:
state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states).
block_ids: list[int] - the list of block indices for the state to copy.
cur_block_idx: int - current block index within `block_ids` to copy from.
num_accepted_tokens: int - number of accepted tokens used to compute the copy offset.
Range: 1 .. 1 + num_speculative_tokens (inclusive).
"""
def get_conv_copy_spec(
state: torch.Tensor,
block_ids: list[int],
cur_block_idx: int,
num_accepted_tokens: int,
) -> MambaCopySpec:
"""Return a MambaCopySpec for copying a convolutional state slice."""
src_block_id = block_ids[cur_block_idx]
src_state = state[src_block_id, num_accepted_tokens - 1 :]
return MambaCopySpec(
start_addr=src_state.data_ptr(), num_elements=src_state.numel()
)
def get_temporal_copy_spec(
state: torch.Tensor,
block_ids: list[int],
cur_block_idx: int,
num_accepted_tokens: int,
) -> MambaCopySpec:
"""Return a MambaCopySpec for copying a temporal state slice."""
src_block_id = block_ids[cur_block_idx + num_accepted_tokens - 1]
src_state = state[src_block_id]
return MambaCopySpec(
start_addr=src_state.data_ptr(), num_elements=src_state.numel()
)
get_full_copy_spec = get_temporal_copy_spec
class MambaStateCopyFuncCalculator:
    """Per-architecture lookup of the copy-spec functions for Mamba states.

    Each classmethod returns a tuple with one copy function per state entry
    of that layer type.
    NOTE(review): the tuple order presumably has to match the order of the
    state shapes/dtypes reported for the same layer type - confirm against
    MambaStateShapeCalculator.
    """

    @classmethod
    def linear_attention_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
        """Copy functions for linear attention: one temporal state."""
        return (get_temporal_copy_spec,)

    @classmethod
    def mamba1_state_copy_func(
        cls,
    ) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Copy functions for Mamba1: (conv state, temporal state)."""
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def mamba2_state_copy_func(
        cls,
    ) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Copy functions for Mamba2: (conv state, temporal state)."""
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def short_conv_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
        """Copy functions for short-conv layers: one conv state."""
        return (get_conv_copy_spec,)

    @classmethod
    def gated_delta_net_state_copy_func(
        cls,
    ) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Copy functions for gated delta net: (conv state, temporal state)."""
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def kda_state_copy_func(
        cls,
    ) -> tuple[
        MambaStateCopyFunc,
        MambaStateCopyFunc,
        MambaStateCopyFunc,
        MambaStateCopyFunc,
    ]:
        """Copy functions for KDA: three conv states, then one temporal state."""
        return (
            get_conv_copy_spec,
            get_conv_copy_spec,
            get_conv_copy_spec,
            get_temporal_copy_spec,
        )
...@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import ( ...@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
...@@ -455,6 +457,10 @@ class BambaForCausalLM( ...@@ -455,6 +457,10 @@ class BambaForCausalLM(
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
) )
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
......
...@@ -330,11 +330,40 @@ class MambaModelConfig(VerifyAndUpdateConfig): ...@@ -330,11 +330,40 @@ class MambaModelConfig(VerifyAndUpdateConfig):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
if cache_config.enable_prefix_caching: if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching: if cache_config.mamba_cache_mode == "none":
cache_config.mamba_cache_mode = (
"all" if model_config.supports_mamba_prefix_caching else "align"
)
logger.warning(
"Mamba cache mode is set to '%s' for %s by default "
"when prefix caching is enabled",
cache_config.mamba_cache_mode,
model_config.architecture,
)
if (
cache_config.mamba_cache_mode == "all"
and not model_config.supports_mamba_prefix_caching
):
cache_config.mamba_cache_mode = "align"
logger.warning(
"Hybrid or mamba-based model detected without support "
"for prefix caching with Mamba cache 'all' mode: "
"falling back to 'align' mode."
)
if cache_config.mamba_cache_mode == "align":
assert vllm_config.scheduler_config.enable_chunked_prefill, (
"Chunked prefill is required for mamba cache mode 'align'."
)
assert not vllm_config.speculative_config, (
"Mamba cache mode 'align' is currently not compatible "
"with speculative decoding."
)
logger.info( logger.info(
"Warning: Prefix caching is currently enabled. " "Warning: Prefix caching in Mamba cache '%s' "
"mode is currently enabled. "
"Its support for Mamba layers is experimental. " "Its support for Mamba layers is experimental. "
"Please report any issues you may observe." "Please report any issues you may observe.",
cache_config.mamba_cache_mode,
) )
# By default, mamba block size will be set to max_model_len (see # By default, mamba block size will be set to max_model_len (see
# below). When enabling prefix caching, we align mamba block size # below). When enabling prefix caching, we align mamba block size
...@@ -342,12 +371,11 @@ class MambaModelConfig(VerifyAndUpdateConfig): ...@@ -342,12 +371,11 @@ class MambaModelConfig(VerifyAndUpdateConfig):
if cache_config.mamba_block_size is None: if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = cache_config.block_size cache_config.mamba_block_size = cache_config.block_size
else: else:
logger.info( if cache_config.mamba_cache_mode != "none":
"Hybrid or mamba-based model detected without " cache_config.mamba_cache_mode = "none"
"support for prefix caching: disabling." logger.warning(
"Mamba cache mode is set to 'none' when prefix caching is disabled"
) )
cache_config.enable_prefix_caching = False
if cache_config.mamba_block_size is None: if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len cache_config.mamba_block_size = model_config.max_model_len
...@@ -426,7 +454,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): ...@@ -426,7 +454,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
mamba_page_size = MambaSpec( mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=model_config.max_model_len, block_size=-1, # block_size doesn't matter for mamba page size
).page_size_bytes ).page_size_bytes
# Model may be marked as is_hybrid # Model may be marked as is_hybrid
...@@ -435,7 +463,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): ...@@ -435,7 +463,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
if mamba_page_size == 0: if mamba_page_size == 0:
return return
if cache_config.enable_prefix_caching: if cache_config.mamba_cache_mode == "all":
# With prefix caching, select attention block size to # With prefix caching, select attention block size to
# optimize for mamba kernel performance # optimize for mamba kernel performance
...@@ -479,6 +507,13 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): ...@@ -479,6 +507,13 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
attn_block_size, attn_block_size,
) )
# By default, mamba block size will be set to max_model_len.
# When enabling prefix caching and using align mamba cache
# mode, we align mamba block size to the block size as the
# basic granularity for prefix caching.
if cache_config.mamba_cache_mode == "align":
cache_config.mamba_block_size = cache_config.block_size
# compute new attention page size # compute new attention page size
attn_page_size = cache_config.block_size * attn_page_size_1_token attn_page_size = cache_config.block_size * attn_page_size_1_token
......
...@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import ( ...@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
...@@ -551,6 +553,10 @@ class FalconH1ForCausalLM( ...@@ -551,6 +553,10 @@ class FalconH1ForCausalLM(
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
) )
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
......
...@@ -19,6 +19,8 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLine ...@@ -19,6 +19,8 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLine
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
...@@ -641,6 +643,10 @@ class GraniteMoeHybridForCausalLM( ...@@ -641,6 +643,10 @@ class GraniteMoeHybridForCausalLM(
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
) )
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -24,6 +24,7 @@ from vllm.config import ModelConfig, SpeechToTextConfig ...@@ -24,6 +24,7 @@ from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.collection_utils import common_prefix from vllm.utils.collection_utils import common_prefix
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
...@@ -776,6 +777,19 @@ class IsHybrid(Protocol): ...@@ -776,6 +777,19 @@ class IsHybrid(Protocol):
""" """
... ...
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
"""Calculate copy-function callables for each Mamba state.
Returns:
A tuple of MambaStateCopyFunc callables that correspond, in order,
to the Mamba states produced by the model. Each callable accepts
(state, block_ids, cur_block_idx, num_accepted_tokens) and returns
a MambaCopySpec describing the memory-copy parameters for prefix
caching in align mode.
"""
...
@overload @overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]: ... def is_hybrid(model: object) -> TypeIs[IsHybrid]: ...
......
...@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import ( ...@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
...@@ -558,6 +560,10 @@ class JambaForCausalLM( ...@@ -558,6 +560,10 @@ class JambaForCausalLM(
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
) )
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -26,6 +26,8 @@ from vllm.model_executor.layers.linear import ( ...@@ -26,6 +26,8 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
...@@ -544,6 +546,14 @@ class KimiLinearForCausalLM( ...@@ -544,6 +546,14 @@ class KimiLinearForCausalLM(
num_spec=num_spec, num_spec=num_spec,
) )
@classmethod
def get_mamba_state_copy_func(
cls,
) -> tuple[
MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
]:
return MambaStateCopyFuncCalculator.kda_state_copy_func()
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import ( ...@@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
...@@ -459,13 +461,18 @@ class Lfm2ForCausalLM( ...@@ -459,13 +461,18 @@ class Lfm2ForCausalLM(
conv_kernel=hf_config.conv_L_cache, conv_kernel=hf_config.conv_L_cache,
) )
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.short_conv_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
if cache_config.mamba_cache_mode == "all":
assert not cache_config.enable_prefix_caching, ( raise NotImplementedError(
"Lfm2 currently does not support prefix caching" "Lfm2 currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
) )
super().__init__() super().__init__()
......
...@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import ( ...@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
...@@ -640,6 +642,10 @@ class Lfm2MoeForCausalLM( ...@@ -640,6 +642,10 @@ class Lfm2MoeForCausalLM(
conv_kernel=hf_config.conv_L_cache, conv_kernel=hf_config.conv_L_cache,
) )
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.short_conv_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
......
...@@ -16,6 +16,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -16,6 +16,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
...@@ -261,6 +263,10 @@ class MambaForCausalLM( ...@@ -261,6 +263,10 @@ class MambaForCausalLM(
conv_kernel=hf_config.conv_kernel, conv_kernel=hf_config.conv_kernel,
) )
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
......
...@@ -15,6 +15,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -15,6 +15,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
...@@ -228,6 +230,10 @@ class Mamba2ForCausalLM( ...@@ -228,6 +230,10 @@ class Mamba2ForCausalLM(
conv_kernel=hf_config.conv_kernel, conv_kernel=hf_config.conv_kernel,
) )
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment