Unverified commit 92737542, authored by Asaf Joseph Gardin, committed by GitHub
Browse files

[Hybrid] Added supports_mamba_prefix_caching Protocol (#27339)


Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
parent f4e81540
......@@ -1656,6 +1656,10 @@ class ModelConfig:
def has_inner_state(self):
return self._model_info.has_inner_state
@property
def supports_mamba_prefix_caching(self) -> bool:
    """Report the mamba prefix-caching flag recorded on ``self._model_info``.

    Returns:
        The value of ``self._model_info.supports_mamba_prefix_caching``,
        passed through unchanged.
    """
    return self._model_info.supports_mamba_prefix_caching
@property
def use_mla(self) -> bool:
    """Report whether MLA should be used for this model.

    Returns:
        A truthy result only when ``self.is_deepseek_mla`` is truthy and
        ``envs.VLLM_MLA_DISABLE`` is falsy. When ``self.is_deepseek_mla`` is
        falsy it is returned as is and the environment flag is not consulted.
    """
    return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
......
......@@ -37,7 +37,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
from .interfaces import (
HasInnerState,
IsHybrid,
SupportsLoRA,
SupportsMambaPrefixCaching,
SupportsPP,
SupportsQuant,
)
from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
......@@ -394,7 +401,13 @@ class BambaModel(nn.Module):
class BambaForCausalLM(
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
nn.Module,
HasInnerState,
SupportsLoRA,
SupportsPP,
IsHybrid,
SupportsQuant,
SupportsMambaPrefixCaching,
):
packed_modules_mapping = {
"qkv_proj": [
......
......@@ -295,17 +295,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
# override by prefix caching logic later)
cache_config.mamba_block_size = model_config.max_model_len
# TODO(@tdoublep) find a better way to do this than whitelist
MAMBA2_MODELS = [
"BambaForCausalLM",
"FalconH1ForCausalLM",
"GraniteMoeHybridForCausalLM",
"Mamba2ForCausalLM",
"NemotronHForCausalLM",
"Zamba2ForCausalLM",
]
if cache_config.enable_prefix_caching:
if model_config.architecture in MAMBA2_MODELS:
if model_config.supports_mamba_prefix_caching:
logger.info(
"Warning: Prefix caching is currently enabled. "
"Its support for Mamba2 layers is experimental. "
......
......@@ -37,7 +37,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .interfaces import (
HasInnerState,
IsHybrid,
SupportsLoRA,
SupportsMambaPrefixCaching,
SupportsPP,
)
from .utils import (
PPMissingLayer,
is_pp_missing_parameter,
......@@ -495,7 +501,14 @@ class FalconH1Model(nn.Module):
return hidden_states
class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid):
class FalconH1ForCausalLM(
nn.Module,
HasInnerState,
SupportsLoRA,
SupportsPP,
IsHybrid,
SupportsMambaPrefixCaching,
):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
......
......@@ -34,7 +34,14 @@ from vllm.sequence import IntermediateTensors
from .granitemoe import GraniteMoeMoE
from .granitemoeshared import GraniteMoeSharedMLP
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
from .interfaces import (
HasInnerState,
IsHybrid,
SupportsLoRA,
SupportsMambaPrefixCaching,
SupportsPP,
SupportsQuant,
)
from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
......@@ -584,7 +591,13 @@ class GraniteMoeHybridModel(nn.Module):
class GraniteMoeHybridForCausalLM(
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
nn.Module,
HasInnerState,
SupportsLoRA,
SupportsPP,
IsHybrid,
SupportsQuant,
SupportsMambaPrefixCaching,
):
packed_modules_mapping = {
"qkv_proj": [
......
......@@ -697,6 +697,34 @@ def has_noops(
return getattr(model, "has_noops", False)
@runtime_checkable
class SupportsMambaPrefixCaching(Protocol):
    """The interface for models whose mamba layers support prefix caching.

    This is currently experimental.

    The protocol carries no methods: it consists solely of a class-level
    marker flag. A model class that lists this protocol among its bases
    inherits the flag with the value ``True``.
    """

    # Marker flag; always True on classes that inherit from this protocol.
    supports_mamba_prefix_caching: ClassVar[Literal[True]] = True
@overload
def supports_mamba_prefix_caching(
    model: object,
) -> TypeIs[SupportsMambaPrefixCaching]: ...


@overload
def supports_mamba_prefix_caching(
    model: type[object],
) -> TypeIs[type[SupportsMambaPrefixCaching]]: ...


def supports_mamba_prefix_caching(
    model: type[object] | object,
) -> TypeIs[type[SupportsMambaPrefixCaching]] | TypeIs[SupportsMambaPrefixCaching]:
    """Check whether a model advertises mamba prefix-caching support.

    Args:
        model: A model class or a model instance; both forms are accepted,
            as reflected by the two overloads above.

    Returns:
        The value of the ``supports_mamba_prefix_caching`` attribute found on
        ``model``, or ``False`` when the attribute is absent.
    """
    # Attribute lookup (rather than isinstance) works for classes and
    # instances alike; a missing flag is treated as "not supported".
    return getattr(model, "supports_mamba_prefix_caching", False)
@runtime_checkable
class SupportsCrossEncoding(Protocol):
"""The interface required for all models that support cross encoding."""
......
......@@ -25,7 +25,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import HasInnerState, IsAttentionFree
from vllm.model_executor.models.interfaces import (
HasInnerState,
IsAttentionFree,
SupportsMambaPrefixCaching,
)
from vllm.sequence import IntermediateTensors
from .utils import (
......@@ -189,7 +193,9 @@ class Mamba2Model(nn.Module):
return loaded_params
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
class Mamba2ForCausalLM(
nn.Module, HasInnerState, IsAttentionFree, SupportsMambaPrefixCaching
):
@classmethod
def get_mamba_state_dtype_from_config(
cls,
......
......@@ -62,6 +62,7 @@ from vllm.model_executor.models.interfaces import (
IsHybrid,
MixtureOfExperts,
SupportsLoRA,
SupportsMambaPrefixCaching,
SupportsPP,
SupportsQuant,
)
......@@ -695,6 +696,7 @@ class NemotronHForCausalLM(
IsHybrid,
SupportsQuant,
MixtureOfExperts,
SupportsMambaPrefixCaching,
):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"backbone": "model"},
......
......@@ -39,6 +39,7 @@ from .interfaces import (
is_attention_free,
is_hybrid,
supports_cross_encoding,
supports_mamba_prefix_caching,
supports_multimodal,
supports_multimodal_encoder_tp_data,
supports_multimodal_raw_input_only,
......@@ -496,6 +497,7 @@ class _ModelInfo:
is_attention_free: bool
is_hybrid: bool
has_noops: bool
supports_mamba_prefix_caching: bool
supports_transcription: bool
supports_transcription_only: bool
......@@ -518,6 +520,7 @@ class _ModelInfo:
has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model),
is_hybrid=is_hybrid(model),
supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
supports_transcription=supports_transcription(model),
supports_transcription_only=(
supports_transcription(model) and model.supports_transcription_only
......
......@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid
from .interfaces import HasInnerState, IsHybrid, SupportsMambaPrefixCaching
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
......@@ -824,7 +824,7 @@ class Zamba2Model(nn.Module):
return loaded_params
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixCaching):
"""Zamba2 model with causal language modeling head.
This class wraps the core Zamba2 model and adds:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment