Unverified commit 5206e5e2, authored by Harry Huang, committed by GitHub
Browse files

[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)


Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
parent fec9da0a
......@@ -24,7 +24,7 @@ pytestmark = pytest.mark.cpu_test
def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True):
return SlidingWindowManager(
sliding_window_spec,
block_pool,
block_pool=block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
)
......@@ -35,7 +35,7 @@ def get_chunked_local_attention_manager(
):
return ChunkedLocalAttentionManager(
chunked_local_attention_spec,
block_pool,
block_pool=block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
)
......@@ -342,11 +342,15 @@ def test_get_num_blocks_to_allocate():
]
assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
manager.get_num_blocks_to_allocate(
"1", 20 * block_size, cached_blocks_1, 0, 20 * block_size
)
== 20
)
assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
manager.get_num_blocks_to_allocate(
"2", 20 * block_size, cached_blocks_2, 0, 20 * block_size
)
== 15
)
......@@ -375,6 +379,7 @@ def test_evictable_cached_blocks_not_double_allocated():
num_tokens=2 * block_size,
new_computed_blocks=[evictable_block],
total_computed_tokens=block_size,
num_tokens_main_model=2 * block_size,
)
# Free capacity check should count evictable cached blocks, but allocation
# should only allocate the truly new block.
......@@ -386,7 +391,9 @@ def test_evictable_cached_blocks_not_double_allocated():
num_local_computed_tokens=block_size,
num_external_computed_tokens=0,
)
new_blocks = manager.allocate_new_blocks(request_id, num_tokens=4)
new_blocks = manager.allocate_new_blocks(
request_id, num_tokens=4, num_tokens_main_model=4
)
assert len(new_blocks) == 1
assert len(manager.req_to_blocks[request_id]) == 2
......@@ -411,10 +418,14 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
]
assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
manager.get_num_blocks_to_allocate(
"1", 20 * block_size, cached_blocks_1, 0, 20 * block_size
)
== 20
)
assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
manager.get_num_blocks_to_allocate(
"2", 20 * block_size, cached_blocks_2, 0, 20 * block_size
)
== 15
)
This diff is collapsed.
......@@ -31,6 +31,7 @@ CacheDType = Literal[
"fp8_ds_mla",
]
MambaDType = Literal["auto", "float32", "float16"]
MambaCacheMode = Literal["all", "align", "none"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
......@@ -123,6 +124,15 @@ class CacheConfig:
"""The data type to use for the Mamba cache (ssm state only, conv state will
still be controlled by mamba_cache_dtype). If set to 'auto', the data type
for the ssm state will be determined by mamba_cache_dtype."""
mamba_cache_mode: MambaCacheMode = "none"
"""The cache strategy for Mamba layers.
- "none": set when prefix caching is disabled.
- "all": cache the mamba state of all tokens at position i * block_size. This is
the default behavior (for models that support it) when prefix caching is
enabled.
- "align": only cache the mamba state of the last token of each scheduler step and
when the token is at position i * block_size.
"""
# Will be set after profiling.
num_gpu_blocks: int | None = field(default=None, init=False)
......
......@@ -999,6 +999,17 @@ class VllmConfig:
# Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False
if self.cache_config.mamba_cache_mode == "align":
if self.scheduler_config.long_prefill_token_threshold > 0:
assert (
self.scheduler_config.long_prefill_token_threshold
>= self.cache_config.block_size
)
assert not self.scheduler_config.disable_chunked_mm_input, (
"Chunked MM input is required because we need the flexibility to "
"schedule a multiple of block_size tokens even if they are in the "
"middle of a mm input"
)
if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = (
self.compilation_config.debug_dump_path.absolute().expanduser()
......
......@@ -60,6 +60,7 @@ from vllm.config.cache import (
BlockSize,
CacheDType,
KVOffloadingBackend,
MambaCacheMode,
MambaDType,
PrefixCachingHashAlgo,
)
......@@ -556,6 +557,7 @@ class EngineArgs:
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
......@@ -939,6 +941,9 @@ class EngineArgs:
cache_group.add_argument(
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
)
cache_group.add_argument(
"--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
)
cache_group.add_argument(
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
)
......@@ -1416,6 +1421,7 @@ class EngineArgs:
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size,
mamba_cache_mode=self.mamba_cache_mode,
kv_offloading_size=self.kv_offloading_size,
kv_offloading_backend=self.kv_offloading_backend,
)
......
......@@ -56,6 +56,7 @@ class MambaBase(AttentionLayerBase):
block_size=mamba_block_size,
page_size_padded=page_size_padded,
mamba_type=self.mamba_type,
mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode,
num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config
......
......@@ -255,7 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
......@@ -304,7 +304,7 @@ class MambaMixer(MambaBase, CustomOp):
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
if prefix_caching_enabled:
if is_mamba_cache_all:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
torch.split(
attn_metadata.block_idx_last_computed_token,
......@@ -380,7 +380,7 @@ class MambaMixer(MambaBase, CustomOp):
ssm_outputs.append(scan_out_p)
if has_decode:
if prefix_caching_enabled:
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)
......
......@@ -570,7 +570,7 @@ class MambaMixer2(MambaBase, CustomOp):
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
......@@ -622,7 +622,7 @@ class MambaMixer2(MambaBase, CustomOp):
dim=0,
)
if prefix_caching_enabled:
if is_mamba_cache_all:
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
......@@ -701,7 +701,7 @@ class MambaMixer2(MambaBase, CustomOp):
initial_states = None
if has_initial_states_p is not None and prep_initial_states:
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
if is_mamba_cache_all:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, block_idx_last_computed_token_p.unsqueeze(1)
).squeeze(1)
......@@ -729,14 +729,14 @@ class MambaMixer2(MambaBase, CustomOp):
cu_chunk_seqlens=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
initial_states=initial_states,
return_intermediate_states=prefix_caching_enabled,
return_intermediate_states=is_mamba_cache_all,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
state_dtype=ssm_state.dtype,
)
if prefix_caching_enabled:
if is_mamba_cache_all:
# The chunk_stride is the number of chunks per mamba block
# e.g., if mamba_block_size = 512 and chunk_size = 256,
# then chunk_stride = 2
......@@ -815,7 +815,7 @@ class MambaMixer2(MambaBase, CustomOp):
# Process decode requests
if has_decode:
if prefix_caching_enabled:
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeAlias
import torch
from vllm.config.cache import MambaDType
......@@ -223,3 +227,94 @@ class MambaStateShapeCalculator:
conv_state_k_shape,
recurrent_state_shape,
)
@dataclass
class MambaCopySpec:
    """Memory-copy parameters for one Mamba state slice.

    Instances describe what to copy when Mamba states are copied for prefix
    caching in align mode.

    Attributes:
        start_addr: Address at which the memory copy begins.
        num_elements: Count of elements to copy, starting at ``start_addr``.
    """

    start_addr: int
    num_elements: int
# Common call signature of the copy-spec helpers in this module:
# (state, block_ids, cur_block_idx, num_accepted_tokens) -> MambaCopySpec.
# The string that follows the alias documents each positional argument.
MambaStateCopyFunc: TypeAlias = Callable[
    [torch.Tensor, list[int], int, int], MambaCopySpec
]
"""
Type alias for a function that computes a MambaCopySpec for copying state slices.
Parameters:
state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states).
block_ids: list[int] - the list of block indices for the state to copy.
cur_block_idx: int - current block index within `block_ids` to copy from.
num_accepted_tokens: int - number of accepted tokens used to compute the copy offset.
Range: 1 .. 1 + num_speculative_tokens (inclusive).
"""
def get_conv_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Build the copy spec for a convolutional state slice.

    The source block is ``block_ids[cur_block_idx]``. Within that block the
    slice starts at index ``num_accepted_tokens - 1`` along the next
    dimension and runs to its end.

    Args:
        state: Convolutional state tensor, indexed first by block id.
        block_ids: Block ids belonging to the state being copied.
        cur_block_idx: Position in ``block_ids`` of the block to copy from.
        num_accepted_tokens: Accepted-token count; determines the offset at
            which the slice starts.

    Returns:
        A ``MambaCopySpec`` holding the slice's data pointer and element count.
    """
    # NOTE(review): describing the slice by (data_ptr, numel) presumably
    # relies on the slice being contiguous in memory - confirm with callers.
    first_kept = num_accepted_tokens - 1
    conv_slice = state[block_ids[cur_block_idx], first_kept:]
    addr = conv_slice.data_ptr()
    count = conv_slice.numel()
    return MambaCopySpec(start_addr=addr, num_elements=count)
def get_temporal_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Build the copy spec for a temporal state slice.

    Unlike the convolutional variant, the accepted-token count selects
    *which block* is copied: the source block is
    ``block_ids[cur_block_idx + num_accepted_tokens - 1]`` and that block's
    whole state is described.

    Args:
        state: Temporal state tensor, indexed first by block id.
        block_ids: Block ids belonging to the state being copied.
        cur_block_idx: Base position in ``block_ids``.
        num_accepted_tokens: Accepted-token count; shifts the position that
            is read from ``block_ids``.

    Returns:
        A ``MambaCopySpec`` holding the block's data pointer and element count.
    """
    block_pos = cur_block_idx + num_accepted_tokens - 1
    block_state = state[block_ids[block_pos]]
    addr = block_state.data_ptr()
    count = block_state.numel()
    return MambaCopySpec(start_addr=addr, num_elements=count)
# Alias: a "full" copy spec is the very same function object as the temporal
# one, i.e. it describes the entire state of the selected block.
get_full_copy_spec = get_temporal_copy_spec
class MambaStateCopyFuncCalculator:
    """Provide the ordered copy-spec functions for each supported layer type.

    Every classmethod returns a tuple of copy-spec functions
    (``get_conv_copy_spec`` / ``get_temporal_copy_spec``), one entry per
    state, in the order listed in the tuple.
    """

    @classmethod
    def linear_attention_state_copy_func(cls):
        """Return a 1-tuple holding the temporal copy-spec function."""
        return (get_temporal_copy_spec,)

    @classmethod
    def mamba1_state_copy_func(cls):
        """Return ``(conv, temporal)`` copy-spec functions."""
        conv_fn, temporal_fn = get_conv_copy_spec, get_temporal_copy_spec
        return (conv_fn, temporal_fn)

    @classmethod
    def mamba2_state_copy_func(cls):
        """Return ``(conv, temporal)`` copy-spec functions."""
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def short_conv_state_copy_func(cls):
        """Return a 1-tuple holding the conv copy-spec function."""
        return (get_conv_copy_spec,)

    @classmethod
    def gated_delta_net_state_copy_func(cls):
        """Return ``(conv, temporal)`` copy-spec functions."""
        conv_fn, temporal_fn = get_conv_copy_spec, get_temporal_copy_spec
        return (conv_fn, temporal_fn)

    @classmethod
    def kda_state_copy_func(cls):
        """Return three conv copy-spec functions followed by one temporal."""
        return (get_conv_copy_spec,) * 3 + (get_temporal_copy_spec,)
......@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -455,6 +457,10 @@ class BambaForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the copy-spec functions for this model's Mamba states.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func``,
    which yields the conv-state and temporal-state functions, in that order.
    """
    return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
......
......@@ -330,11 +330,40 @@ class MambaModelConfig(VerifyAndUpdateConfig):
cache_config = vllm_config.cache_config
if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching:
if cache_config.mamba_cache_mode == "none":
cache_config.mamba_cache_mode = (
"all" if model_config.supports_mamba_prefix_caching else "align"
)
logger.warning(
"Mamba cache mode is set to '%s' for %s by default "
"when prefix caching is enabled",
cache_config.mamba_cache_mode,
model_config.architecture,
)
if (
cache_config.mamba_cache_mode == "all"
and not model_config.supports_mamba_prefix_caching
):
cache_config.mamba_cache_mode = "align"
logger.warning(
"Hybrid or mamba-based model detected without support "
"for prefix caching with Mamba cache 'all' mode: "
"falling back to 'align' mode."
)
if cache_config.mamba_cache_mode == "align":
assert vllm_config.scheduler_config.enable_chunked_prefill, (
"Chunked prefill is required for mamba cache mode 'align'."
)
assert not vllm_config.speculative_config, (
"Mamba cache mode 'align' is currently not compatible "
"with speculative decoding."
)
logger.info(
"Warning: Prefix caching is currently enabled. "
"Warning: Prefix caching in Mamba cache '%s' "
"mode is currently enabled. "
"Its support for Mamba layers is experimental. "
"Please report any issues you may observe."
"Please report any issues you may observe.",
cache_config.mamba_cache_mode,
)
# By default, mamba block size will be set to max_model_len (see
# below). When enabling prefix caching, we align mamba block size
......@@ -342,12 +371,11 @@ class MambaModelConfig(VerifyAndUpdateConfig):
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = cache_config.block_size
else:
logger.info(
"Hybrid or mamba-based model detected without "
"support for prefix caching: disabling."
if cache_config.mamba_cache_mode != "none":
cache_config.mamba_cache_mode = "none"
logger.warning(
"Mamba cache mode is set to 'none' when prefix caching is disabled"
)
cache_config.enable_prefix_caching = False
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
......@@ -426,7 +454,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=model_config.max_model_len,
block_size=-1, # block_size doesn't matter for mamba page size
).page_size_bytes
# Model may be marked as is_hybrid
......@@ -435,7 +463,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
if mamba_page_size == 0:
return
if cache_config.enable_prefix_caching:
if cache_config.mamba_cache_mode == "all":
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
......@@ -479,6 +507,13 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
attn_block_size,
)
# By default, mamba block size will be set to max_model_len.
# When enabling prefix caching and using align mamba cache
# mode, we align mamba block size to the block size as the
# basic granularity for prefix caching.
if cache_config.mamba_cache_mode == "align":
cache_config.mamba_block_size = cache_config.block_size
# compute new attention page size
attn_page_size = cache_config.block_size * attn_page_size_1_token
......
......@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -551,6 +553,10 @@ class FalconH1ForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the copy-spec functions for this model's Mamba states.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func``,
    which yields the conv-state and temporal-state functions, in that order.
    """
    return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
......
......@@ -19,6 +19,8 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLine
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -641,6 +643,10 @@ class GraniteMoeHybridForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the copy-spec functions for this model's Mamba states.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func``,
    which yields the conv-state and temporal-state functions, in that order.
    """
    return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -24,6 +24,7 @@ from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.collection_utils import common_prefix
from vllm.utils.func_utils import supports_kw
......@@ -776,6 +777,19 @@ class IsHybrid(Protocol):
"""
...
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
    """Calculate copy-function callables for each Mamba state.

    Returns:
        A tuple of MambaStateCopyFunc callables that correspond, in order,
        to the Mamba states produced by the model. Each callable accepts
        (state, block_ids, cur_block_idx, num_accepted_tokens) and returns
        a MambaCopySpec describing the memory-copy parameters for prefix
        caching in align mode.
    """
    ...
@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]: ...
......
......@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -558,6 +560,10 @@ class JambaForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the copy-spec functions for this model's Mamba states.

    Delegates to ``MambaStateCopyFuncCalculator.mamba1_state_copy_func``,
    which yields the conv-state and temporal-state functions, in that order.
    """
    return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -26,6 +26,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -544,6 +546,14 @@ class KimiLinearForCausalLM(
num_spec=num_spec,
)
@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[
    MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
]:
    """Return the copy-spec functions for this model's four states.

    Delegates to ``MambaStateCopyFuncCalculator.kda_state_copy_func``, which
    yields three conv-state functions followed by one temporal-state function.
    """
    return MambaStateCopyFuncCalculator.kda_state_copy_func()
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -459,13 +461,18 @@ class Lfm2ForCausalLM(
conv_kernel=hf_config.conv_L_cache,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
    """Return the copy-spec function for this model's single conv state.

    Delegates to ``MambaStateCopyFuncCalculator.short_conv_state_copy_func``,
    which yields a 1-tuple holding the conv-state function.
    """
    return MambaStateCopyFuncCalculator.short_conv_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, (
"Lfm2 currently does not support prefix caching"
if cache_config.mamba_cache_mode == "all":
raise NotImplementedError(
"Lfm2 currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
super().__init__()
......
......@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -640,6 +642,10 @@ class Lfm2MoeForCausalLM(
conv_kernel=hf_config.conv_L_cache,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
    """Return the copy-spec function for this model's single conv state.

    Delegates to ``MambaStateCopyFuncCalculator.short_conv_state_copy_func``,
    which yields a 1-tuple holding the conv-state function.
    """
    return MambaStateCopyFuncCalculator.short_conv_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......
......@@ -16,6 +16,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -261,6 +263,10 @@ class MambaForCausalLM(
conv_kernel=hf_config.conv_kernel,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the copy-spec functions for this model's Mamba states.

    Delegates to ``MambaStateCopyFuncCalculator.mamba1_state_copy_func``,
    which yields the conv-state and temporal-state functions, in that order.
    """
    return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
......
......@@ -15,6 +15,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -228,6 +230,10 @@ class Mamba2ForCausalLM(
conv_kernel=hf_config.conv_kernel,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the copy-spec functions for this model's Mamba states.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func``,
    which yields the conv-state and temporal-state functions, in that order.
    """
    return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment