Unverified commit 92737542, authored by Asaf Joseph Gardin and committed by GitHub
Browse files

[Hybrid] Added supports_mamba_prefix_caching Protocol (#27339)


Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
parent f4e81540
...@@ -1656,6 +1656,10 @@ class ModelConfig: ...@@ -1656,6 +1656,10 @@ class ModelConfig:
def has_inner_state(self): def has_inner_state(self):
return self._model_info.has_inner_state return self._model_info.has_inner_state
@property
def supports_mamba_prefix_caching(self) -> bool:
    """Report whether the model advertises Mamba prefix-caching support.

    The value is read straight from the cached model info record
    (``self._model_info``); no additional computation happens here.

    Returns:
        The ``supports_mamba_prefix_caching`` flag stored on the model info.
    """
    return self._model_info.supports_mamba_prefix_caching
@property @property
def use_mla(self) -> bool: def use_mla(self) -> bool:
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
......
...@@ -37,7 +37,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -37,7 +37,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant from .interfaces import (
HasInnerState,
IsHybrid,
SupportsLoRA,
SupportsMambaPrefixCaching,
SupportsPP,
SupportsQuant,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
is_pp_missing_parameter, is_pp_missing_parameter,
...@@ -394,7 +401,13 @@ class BambaModel(nn.Module): ...@@ -394,7 +401,13 @@ class BambaModel(nn.Module):
class BambaForCausalLM( class BambaForCausalLM(
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant nn.Module,
HasInnerState,
SupportsLoRA,
SupportsPP,
IsHybrid,
SupportsQuant,
SupportsMambaPrefixCaching,
): ):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
......
...@@ -295,17 +295,8 @@ class MambaModelConfig(VerifyAndUpdateConfig): ...@@ -295,17 +295,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
# override by prefix caching logic later) # override by prefix caching logic later)
cache_config.mamba_block_size = model_config.max_model_len cache_config.mamba_block_size = model_config.max_model_len
# TODO(@tdoublep) find a better way to do this than whitelist
MAMBA2_MODELS = [
"BambaForCausalLM",
"FalconH1ForCausalLM",
"GraniteMoeHybridForCausalLM",
"Mamba2ForCausalLM",
"NemotronHForCausalLM",
"Zamba2ForCausalLM",
]
if cache_config.enable_prefix_caching: if cache_config.enable_prefix_caching:
if model_config.architecture in MAMBA2_MODELS: if model_config.supports_mamba_prefix_caching:
logger.info( logger.info(
"Warning: Prefix caching is currently enabled. " "Warning: Prefix caching is currently enabled. "
"Its support for Mamba2 layers is experimental. " "Its support for Mamba2 layers is experimental. "
......
...@@ -37,7 +37,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -37,7 +37,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .interfaces import (
HasInnerState,
IsHybrid,
SupportsLoRA,
SupportsMambaPrefixCaching,
SupportsPP,
)
from .utils import ( from .utils import (
PPMissingLayer, PPMissingLayer,
is_pp_missing_parameter, is_pp_missing_parameter,
...@@ -495,7 +501,14 @@ class FalconH1Model(nn.Module): ...@@ -495,7 +501,14 @@ class FalconH1Model(nn.Module):
return hidden_states return hidden_states
class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid): class FalconH1ForCausalLM(
nn.Module,
HasInnerState,
SupportsLoRA,
SupportsPP,
IsHybrid,
SupportsMambaPrefixCaching,
):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
......
...@@ -34,7 +34,14 @@ from vllm.sequence import IntermediateTensors ...@@ -34,7 +34,14 @@ from vllm.sequence import IntermediateTensors
from .granitemoe import GraniteMoeMoE from .granitemoe import GraniteMoeMoE
from .granitemoeshared import GraniteMoeSharedMLP from .granitemoeshared import GraniteMoeSharedMLP
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant from .interfaces import (
HasInnerState,
IsHybrid,
SupportsLoRA,
SupportsMambaPrefixCaching,
SupportsPP,
SupportsQuant,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
is_pp_missing_parameter, is_pp_missing_parameter,
...@@ -584,7 +591,13 @@ class GraniteMoeHybridModel(nn.Module): ...@@ -584,7 +591,13 @@ class GraniteMoeHybridModel(nn.Module):
class GraniteMoeHybridForCausalLM( class GraniteMoeHybridForCausalLM(
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant nn.Module,
HasInnerState,
SupportsLoRA,
SupportsPP,
IsHybrid,
SupportsQuant,
SupportsMambaPrefixCaching,
): ):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
......
...@@ -697,6 +697,34 @@ def has_noops( ...@@ -697,6 +697,34 @@ def has_noops(
return getattr(model, "has_noops", False) return getattr(model, "has_noops", False)
@runtime_checkable
class SupportsMambaPrefixCaching(Protocol):
    """The interface for models whose mamba layers support prefix caching.

    This is currently experimental.

    A model opts in by listing this protocol among its base classes, which
    makes the class-level flag below visible on the model class and on its
    instances.
    """

    # Marker flag: inherited as ``True`` by every class that lists this
    # protocol as a base.
    supports_mamba_prefix_caching: ClassVar[Literal[True]] = True
@overload
def supports_mamba_prefix_caching(
    model: object,
) -> TypeIs[SupportsMambaPrefixCaching]: ...


@overload
def supports_mamba_prefix_caching(
    model: type[object],
) -> TypeIs[type[SupportsMambaPrefixCaching]]: ...


def supports_mamba_prefix_caching(
    model: type[object] | object,
) -> TypeIs[type[SupportsMambaPrefixCaching]] | TypeIs[SupportsMambaPrefixCaching]:
    """Check whether a model advertises Mamba prefix-caching support.

    Args:
        model: A model class or a model instance.

    Returns:
        The value of the ``supports_mamba_prefix_caching`` attribute found on
        ``model``, or ``False`` when the attribute is absent.
    """
    # Attribute lookup (rather than isinstance) lets the same check work for
    # both classes and instances.
    return getattr(model, "supports_mamba_prefix_caching", False)
@runtime_checkable @runtime_checkable
class SupportsCrossEncoding(Protocol): class SupportsCrossEncoding(Protocol):
"""The interface required for all models that support cross encoding.""" """The interface required for all models that support cross encoding."""
......
...@@ -25,7 +25,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -25,7 +25,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import HasInnerState, IsAttentionFree from vllm.model_executor.models.interfaces import (
HasInnerState,
IsAttentionFree,
SupportsMambaPrefixCaching,
)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import ( from .utils import (
...@@ -189,7 +193,9 @@ class Mamba2Model(nn.Module): ...@@ -189,7 +193,9 @@ class Mamba2Model(nn.Module):
return loaded_params return loaded_params
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): class Mamba2ForCausalLM(
nn.Module, HasInnerState, IsAttentionFree, SupportsMambaPrefixCaching
):
@classmethod @classmethod
def get_mamba_state_dtype_from_config( def get_mamba_state_dtype_from_config(
cls, cls,
......
...@@ -62,6 +62,7 @@ from vllm.model_executor.models.interfaces import ( ...@@ -62,6 +62,7 @@ from vllm.model_executor.models.interfaces import (
IsHybrid, IsHybrid,
MixtureOfExperts, MixtureOfExperts,
SupportsLoRA, SupportsLoRA,
SupportsMambaPrefixCaching,
SupportsPP, SupportsPP,
SupportsQuant, SupportsQuant,
) )
...@@ -695,6 +696,7 @@ class NemotronHForCausalLM( ...@@ -695,6 +696,7 @@ class NemotronHForCausalLM(
IsHybrid, IsHybrid,
SupportsQuant, SupportsQuant,
MixtureOfExperts, MixtureOfExperts,
SupportsMambaPrefixCaching,
): ):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"backbone": "model"}, orig_to_new_prefix={"backbone": "model"},
......
...@@ -39,6 +39,7 @@ from .interfaces import ( ...@@ -39,6 +39,7 @@ from .interfaces import (
is_attention_free, is_attention_free,
is_hybrid, is_hybrid,
supports_cross_encoding, supports_cross_encoding,
supports_mamba_prefix_caching,
supports_multimodal, supports_multimodal,
supports_multimodal_encoder_tp_data, supports_multimodal_encoder_tp_data,
supports_multimodal_raw_input_only, supports_multimodal_raw_input_only,
...@@ -496,6 +497,7 @@ class _ModelInfo: ...@@ -496,6 +497,7 @@ class _ModelInfo:
is_attention_free: bool is_attention_free: bool
is_hybrid: bool is_hybrid: bool
has_noops: bool has_noops: bool
supports_mamba_prefix_caching: bool
supports_transcription: bool supports_transcription: bool
supports_transcription_only: bool supports_transcription_only: bool
...@@ -518,6 +520,7 @@ class _ModelInfo: ...@@ -518,6 +520,7 @@ class _ModelInfo:
has_inner_state=has_inner_state(model), has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model), is_attention_free=is_attention_free(model),
is_hybrid=is_hybrid(model), is_hybrid=is_hybrid(model),
supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
supports_transcription=supports_transcription(model), supports_transcription=supports_transcription(model),
supports_transcription_only=( supports_transcription_only=(
supports_transcription(model) and model.supports_transcription_only supports_transcription(model) and model.supports_transcription_only
......
...@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid from .interfaces import HasInnerState, IsHybrid, SupportsMambaPrefixCaching
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
...@@ -824,7 +824,7 @@ class Zamba2Model(nn.Module): ...@@ -824,7 +824,7 @@ class Zamba2Model(nn.Module):
return loaded_params return loaded_params
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixCaching):
"""Zamba2 model with causal language modeling head. """Zamba2 model with causal language modeling head.
This class wraps the core Zamba2 model and adds: This class wraps the core Zamba2 model and adds:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment