Unverified commit a903669e, authored by Thomas Parnell, committed by GitHub
Browse files

[V1] Remove V0 code paths for Hybrid models (#25400)


Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent 2c58742d
...@@ -14,7 +14,6 @@ import torch.distributed ...@@ -14,7 +14,6 @@ import torch.distributed
from torch import nn from torch import nn
from transformers import MiniMaxConfig from transformers import MiniMaxConfig
from vllm import envs
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
...@@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix ...@@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
...@@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module): ...@@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
def forward(self, def forward(self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: Union[list[dict], Optional[torch.Tensor]],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
is_warmup: bool = False, is_warmup: bool = False,
...@@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module): ...@@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_states=layernorm_output, hidden_states=layernorm_output,
output=self_attention_output, output=self_attention_output,
positions=positions, positions=positions,
kv_caches=kv_caches,
) )
residual = residual * self.layernorm_attention_alpha residual = residual * self.layernorm_attention_alpha
...@@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module): ...@@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module):
self._dtype = _dummy.dtype self._dtype = _dummy.dtype
del _dummy del _dummy
if not envs.VLLM_USE_V1:
self.minimax_cache = MinimaxCacheManager(
dtype=torch.float32, cache_shape=self.cache_shape)
norm_kwargs = {} norm_kwargs = {}
if hasattr(config, "rms_norm_eps"): if hasattr(config, "rms_norm_eps"):
norm_kwargs["eps"] = config.rms_norm_eps norm_kwargs["eps"] = config.rms_norm_eps
...@@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module): ...@@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module):
**kwargs) -> Union[torch.Tensor, IntermediateTensors]: **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if not envs.VLLM_USE_V1:
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
else:
minimax_cache_params = None
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is None: if inputs_embeds is None:
...@@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module): ...@@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
minimax_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
_caches = None
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
current_state_layer)
minimax_cache_index += 1
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states=hidden_states, hidden_states=hidden_states,
positions=positions, positions=positions,
kv_caches=_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
residual=residual, residual=residual,
) )
...@@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, ...], ...]: ) -> tuple[tuple[int, ...], ...]:
"""Calculate shape for MiniMaxText01LinearAttention cache. """Calculate shape for MiniMaxText01LinearAttention cache.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
......
...@@ -23,21 +23,17 @@ from typing import Optional ...@@ -23,21 +23,17 @@ from typing import Optional
import torch import torch
from torch import nn from torch import nn
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
...@@ -49,14 +45,11 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -49,14 +45,11 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsLoRA, SupportsPP, SupportsLoRA, SupportsPP,
SupportsQuant) SupportsQuant)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory, AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix) make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig from vllm.transformers_utils.configs import NemotronHConfig
from vllm.utils import LayerBlockType
class NemotronHMLP(nn.Module): class NemotronHMLP(nn.Module):
...@@ -181,8 +174,6 @@ class NemotronHMambaDecoderLayer(nn.Module): ...@@ -181,8 +174,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
...@@ -192,7 +183,7 @@ class NemotronHMambaDecoderLayer(nn.Module): ...@@ -192,7 +183,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) self.mixer(hidden_states, output)
return output, residual return output, residual
...@@ -370,22 +361,10 @@ class NemotronHModel(nn.Module): ...@@ -370,22 +361,10 @@ class NemotronHModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -398,22 +377,11 @@ class NemotronHModel(nn.Module): ...@@ -398,22 +377,11 @@ class NemotronHModel(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
residual = None residual = None
num_non_mamba_layers = 0
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
layer_mamba_cache_params = None
if isinstance(layer,
NemotronHMambaDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_non_mamba_layers)
else:
num_non_mamba_layers += 1
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -508,13 +476,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -508,13 +476,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
...@@ -533,7 +499,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -533,7 +499,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim=hf_config.mamba_head_dim, head_dim=hf_config.mamba_head_dim,
state_size=hf_config.ssm_state_size, state_size=hf_config.ssm_state_size,
conv_kernel=hf_config.conv_kernel, conv_kernel=hf_config.conv_kernel,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -566,8 +531,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -566,8 +531,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
...@@ -584,40 +547,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -584,40 +547,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
mamba_cache_params = None hidden_states = self.model(input_ids, positions, intermediate_tensors,
if not envs.VLLM_USE_V1: inputs_embeds)
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
This diff is collapsed.
...@@ -12,7 +12,6 @@ import torch ...@@ -12,7 +12,6 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
...@@ -29,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -29,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata, update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
...@@ -47,15 +44,13 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -47,15 +44,13 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader) composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsPP) SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
is_pp_missing_parameter, make_empty_intermediate_tensors_factory, is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix) make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
...@@ -194,16 +189,12 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -194,16 +189,12 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self.chunk_size = self.config.mamba_chunk_size self.chunk_size = self.config.mamba_chunk_size
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path # The tuple is (conv_state, ssm_state)
# only runs for v1, we have to do this to unify with the interface self.kv_cache = (torch.tensor([]), torch.tensor([]))
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
assert self.chunk_size != -1, "chunk_size must be set for v1" assert self.chunk_size != -1, "chunk_size must be set for v1"
self.prefix = prefix self.prefix = prefix
...@@ -227,8 +218,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -227,8 +218,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
pass pass
...@@ -237,14 +226,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -237,14 +226,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
mamba2_metadata)
else:
torch.ops.vllm.plamo2_mamba_mixer( torch.ops.vllm.plamo2_mamba_mixer(
hidden_states, hidden_states,
output, output,
...@@ -255,41 +238,31 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -255,41 +238,31 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
forward_context = get_forward_context() forward_context = get_forward_context()
# mamba2_metadata contains metadata necessary for the mamba2 triton # attn_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill # kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they # modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration # stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None: if attn_metadata is not None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
mamba2_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
else: has_initial_states_p = attn_metadata.has_initial_states_p
conv_state = mamba_cache_params.conv_state prep_initial_states = attn_metadata.prep_initial_states
ssm_state = mamba_cache_params.ssm_state chunk_size = attn_metadata.chunk_size
state_indices_tensor = mamba_cache_params.state_indices_tensor seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
# Common members between V1 metadata and V0 metadata chunk_offsets_p = attn_metadata.chunk_offsets_p
if mamba2_metadata is not None:
has_initial_states_p = mamba2_metadata.has_initial_states_p
prep_initial_states = mamba2_metadata.prep_initial_states
chunk_size = mamba2_metadata.chunk_size
seq_idx_p = mamba2_metadata.seq_idx_p
chunk_indices_p = mamba2_metadata.chunk_indices_p
chunk_offsets_p = mamba2_metadata.chunk_offsets_p
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states) projected_states = self.in_proj(hidden_states)
...@@ -299,8 +272,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -299,8 +272,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2)) self.conv1d.weight.size(2))
if envs.VLLM_USE_V1 and attn_metadata is None: if attn_metadata is None:
# V1 profile run # profile run
hidden_states = (hidden_states.transpose(0, 1).clone().transpose( hidden_states = (hidden_states.transpose(0, 1).clone().transpose(
0, 1)).contiguous() 0, 1)).contiguous()
output[:] = self.out_proj(hidden_states) output[:] = self.out_proj(hidden_states)
...@@ -316,7 +289,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -316,7 +289,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill # NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input # Separate prefill and decode by splitting varlen input
# Split along token dimension # Split along token dimension
if envs.VLLM_USE_V1:
hidden_states_d, hidden_states_p = torch.split( hidden_states_d, hidden_states_p = torch.split(
hidden_states[:num_actual_tokens], hidden_states[:num_actual_tokens],
[num_decodes, num_prefill_tokens], [num_decodes, num_prefill_tokens],
...@@ -334,24 +306,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -334,24 +306,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
query_start_loc_p = ( query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] - attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None) num_decodes if has_prefill else None)
else:
hidden_states_p, hidden_states_d = torch.split(
hidden_states,
[num_prefill_tokens, num_decodes],
dim=0,
)
gate_p, gate_d = torch.split(gate,
[num_prefill_tokens, num_decodes],
dim=0)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)
# Preallocate output tensor to avoid memcpy cost for merging prefill # Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs # and decode outputs
...@@ -363,18 +317,11 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -363,18 +317,11 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device=hidden_states.device, device=hidden_states.device,
) )
if envs.VLLM_USE_V1:
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out, preallocated_ssm_out,
[num_decodes, num_prefill_tokens], [num_decodes, num_prefill_tokens],
dim=0, dim=0,
) )
else:
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests # Process prefill requests
if has_prefill: if has_prefill:
...@@ -383,9 +330,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -383,9 +330,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
# pointed to by "state_indices_tensor" # pointed to by "state_indices_tensor"
x = hidden_states_p.transpose( x = hidden_states_p.transpose(
0, 1) # this is the form that causal-conv see 0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(x, query_start_loc_p,
mamba2_metadata)
hidden_states_p = causal_conv1d_fn( hidden_states_p = causal_conv1d_fn(
x, x,
conv_weights, conv_weights,
...@@ -394,7 +338,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -394,7 +338,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_states=conv_state, conv_states=conv_state,
has_initial_state=has_initial_states_p, has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p, cache_indices=state_indices_tensor_p,
metadata=mamba2_metadata, metadata=attn_metadata,
query_start_loc=query_start_loc_p) query_start_loc=query_start_loc_p)
hidden_states_p = hidden_states_p.transpose(0, 1) hidden_states_p = hidden_states_p.transpose(0, 1)
hidden_states_p = hidden_states_p[:num_prefill_tokens] hidden_states_p = hidden_states_p[:num_prefill_tokens]
...@@ -470,7 +414,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -470,7 +414,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
-1, self.num_heads // self.tp_size, self.head_dim) -1, self.num_heads // self.tp_size, self.head_dim)
# - the hidden is reshaped into (bs, num_heads, head_dim) # - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected # - ssm_state's slots will be selected
# using state_indices_tensor_d # using state_indices_tensor_d
# NOTE: final output is an in-place update of out tensor # NOTE: final output is an in-place update of out tensor
...@@ -530,10 +474,7 @@ def plamo2_mamba_mixer( ...@@ -530,10 +474,7 @@ def plamo2_mamba_mixer(
) -> None: ) -> None:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states, self.forward_cuda(hidden_states=hidden_states, output=output)
output=output,
mamba_cache_params=None,
mamba2_metadata=None)
def plamo2_mamba_mixer_fake( def plamo2_mamba_mixer_fake(
...@@ -731,8 +672,6 @@ class Plamo2DecoderLayer(nn.Module): ...@@ -731,8 +672,6 @@ class Plamo2DecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
...@@ -747,8 +686,6 @@ class Plamo2DecoderLayer(nn.Module): ...@@ -747,8 +686,6 @@ class Plamo2DecoderLayer(nn.Module):
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
mixer_kwargs = { mixer_kwargs = {
"output": output, "output": output,
"mamba_cache_params": mamba_cache_params,
"mamba2_metadata": mamba2_metadata,
} }
else: else:
mixer_kwargs = { mixer_kwargs = {
...@@ -790,23 +727,12 @@ class Plamo2Decoder(torch.nn.Module): ...@@ -790,23 +727,12 @@ class Plamo2Decoder(torch.nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor: ) -> torch.Tensor:
mamba_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None
if layer.is_mamba and mamba_cache_params is not None:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
mamba_cache_index)
mamba_cache_index += 1
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
return hidden_states, residual return hidden_states, residual
...@@ -844,7 +770,6 @@ class Plamo2Model(torch.nn.Module): ...@@ -844,7 +770,6 @@ class Plamo2Model(torch.nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -859,23 +784,10 @@ class Plamo2Model(torch.nn.Module): ...@@ -859,23 +784,10 @@ class Plamo2Model(torch.nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
if not envs.VLLM_USE_V1:
attn_metadata: AttentionMetadata = get_forward_context(
).attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
hidden_states, residual = self.layers( hidden_states, residual = self.layers(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -925,9 +837,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): ...@@ -925,9 +837,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.config.vocab_size) self.config.vocab_size)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
...@@ -942,39 +851,11 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): ...@@ -942,39 +851,11 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
mamba_state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
hidden_states = self.model(input_ids, positions, mamba_cache_params, hidden_states = self.model(input_ids, positions, intermediate_tensors,
intermediate_tensors, inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
@classmethod @classmethod
def get_mamba_state_dtype_from_config( def get_mamba_state_dtype_from_config(
cls, cls,
...@@ -991,12 +872,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): ...@@ -991,12 +872,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
- conv_state_shape: Shape for convolutional state cache - conv_state_shape: Shape for convolutional state cache
...@@ -1015,7 +894,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): ...@@ -1015,7 +894,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
head_dim=hf_config.hidden_size_per_head, head_dim=hf_config.hidden_size_per_head,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
) )
def compute_logits( def compute_logits(
......
...@@ -11,7 +11,6 @@ from einops import rearrange ...@@ -11,7 +11,6 @@ from einops import rearrange
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from vllm import envs
from vllm.attention import Attention, AttentionBackend, AttentionMetadata from vllm.attention import Attention, AttentionBackend, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
...@@ -35,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -35,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
from vllm.model_executor.layers.mamba.mamba_mixer2 import ( from vllm.model_executor.layers.mamba.mamba_mixer2 import (
mamba_v2_sharded_weight_loader) mamba_v2_sharded_weight_loader)
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
...@@ -51,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -51,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader) default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -198,14 +195,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -198,14 +195,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.gated_delta_net_state_shape( return MambaStateShapeCalculator.gated_delta_net_state_shape(
self.tp_size, self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
self.num_k_heads, self.head_v_dim, self.conv_kernel_size, self.num_spec)
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
self.conv_kernel_size,
self.num_spec,
use_v1=True)
def __init__( def __init__(
self, self,
...@@ -394,7 +385,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -394,7 +385,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
cache_params: Optional[MambaCacheParams] = None,
): ):
return torch.ops.vllm.gdn_attention( return torch.ops.vllm.gdn_attention(
hidden_states, hidden_states,
...@@ -416,7 +406,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -416,7 +406,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, GDNAttentionMetadata) assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc spec_query_start_loc = attn_metadata.spec_query_start_loc
...@@ -479,12 +468,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -479,12 +468,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# 2.2: process the remaining part # 2.2: process the remaining part
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(mixed_qkv_non_spec_T,
non_spec_query_start_loc,
conv_metadata)
# - "cache_indices" updates the conv_state cache in positions # - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor" # pointed to by "state_indices_tensor"
mixed_qkv_non_spec = causal_conv1d_fn( mixed_qkv_non_spec = causal_conv1d_fn(
mixed_qkv_non_spec_T, mixed_qkv_non_spec_T,
conv_weights, conv_weights,
...@@ -494,7 +479,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -494,7 +479,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
has_initial_state=has_initial_state, has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor, cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc, query_start_loc=non_spec_query_start_loc,
metadata=conv_metadata, metadata=attn_metadata,
).transpose(0, 1) ).transpose(0, 1)
elif attn_metadata.num_decodes > 0: elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update( mixed_qkv_non_spec = causal_conv1d_update(
...@@ -1075,7 +1060,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -1075,7 +1060,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Qwen3Next currently does not support prefix caching" "Qwen3Next currently does not support prefix caching"
assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
self.quant_config = vllm_config.quant_config self.quant_config = vllm_config.quant_config
super().__init__() super().__init__()
...@@ -1195,14 +1179,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -1195,14 +1179,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
num_spec = (vllm_config.speculative_config.num_speculative_tokens num_spec = (vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0) if vllm_config.speculative_config else 0)
return MambaStateShapeCalculator.gated_delta_net_state_shape( return MambaStateShapeCalculator.gated_delta_net_state_shape(
tp_size, tp_size, hf_config.linear_num_key_heads,
hf_config.linear_num_key_heads, hf_config.linear_num_value_heads, hf_config.linear_key_head_dim,
hf_config.linear_num_value_heads, hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim,
hf_config.linear_key_head_dim, num_spec)
hf_config.linear_value_head_dim,
hf_config.linear_conv_kernel_dim,
num_spec,
use_v1=True)
def compute_logits( def compute_logits(
self, self,
......
...@@ -134,7 +134,6 @@ _TEXT_GENERATION_MODELS = { ...@@ -134,7 +134,6 @@ _TEXT_GENERATION_MODELS = {
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
......
...@@ -15,12 +15,10 @@ import torch ...@@ -15,12 +15,10 @@ import torch
from torch import nn from torch import nn
from transformers import Zamba2Config from transformers import Zamba2Config
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -29,8 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -29,8 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
...@@ -39,8 +35,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -39,8 +35,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid from .interfaces import HasInnerState, IsHybrid
...@@ -515,8 +509,6 @@ class Zamba2MambaDecoderLayer(nn.Module): ...@@ -515,8 +509,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
transformer_hidden_states: Optional[torch.Tensor] = None, transformer_hidden_states: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None,
original_hidden_states: Optional[torch.Tensor] = None, original_hidden_states: Optional[torch.Tensor] = None,
...@@ -525,8 +517,6 @@ class Zamba2MambaDecoderLayer(nn.Module): ...@@ -525,8 +517,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
Args: Args:
hidden_states: Input tensor [batch_size, seq_len, hidden_size] hidden_states: Input tensor [batch_size, seq_len, hidden_size]
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
transformer_hidden_states: Optional output from transformer path transformer_hidden_states: Optional output from transformer path
Added to input if provided (used in hybrid architecture) Added to input if provided (used in hybrid architecture)
positions: Optional position IDs (unused in Mamba) positions: Optional position IDs (unused in Mamba)
...@@ -555,8 +545,6 @@ class Zamba2MambaDecoderLayer(nn.Module): ...@@ -555,8 +545,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
self.mamba( self.mamba(
hidden_states, hidden_states,
output, output,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
# residual connection after mamba # residual connection after mamba
...@@ -607,8 +595,6 @@ class Zamba2HybridLayer(nn.Module): ...@@ -607,8 +595,6 @@ class Zamba2HybridLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
original_hidden_states: torch.Tensor, original_hidden_states: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass through the hybrid layer. """Forward pass through the hybrid layer.
...@@ -623,8 +609,6 @@ class Zamba2HybridLayer(nn.Module): ...@@ -623,8 +609,6 @@ class Zamba2HybridLayer(nn.Module):
original_hidden_states: Original input for transformer residual original_hidden_states: Original input for transformer residual
connection connection
positions: Position IDs for positional embeddings positions: Position IDs for positional embeddings
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
Returns: Returns:
Output tensor combining transformer and Mamba representations Output tensor combining transformer and Mamba representations
...@@ -644,8 +628,6 @@ class Zamba2HybridLayer(nn.Module): ...@@ -644,8 +628,6 @@ class Zamba2HybridLayer(nn.Module):
layer_outputs = self.mamba_decoder( layer_outputs = self.mamba_decoder(
hidden_states, hidden_states,
transformer_hidden_states=transformer_hidden_states, transformer_hidden_states=transformer_hidden_states,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
return layer_outputs return layer_outputs
...@@ -752,7 +734,6 @@ class Zamba2Model(nn.Module): ...@@ -752,7 +734,6 @@ class Zamba2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Forward pass through the model. """Forward pass through the model.
...@@ -760,8 +741,6 @@ class Zamba2Model(nn.Module): ...@@ -760,8 +741,6 @@ class Zamba2Model(nn.Module):
Args: Args:
input_ids: Input token IDs input_ids: Input token IDs
positions: Position IDs for embeddings positions: Position IDs for embeddings
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
inputs_embeds: Optional pre-computed input embeddings inputs_embeds: Optional pre-computed input embeddings
Returns: Returns:
...@@ -773,33 +752,13 @@ class Zamba2Model(nn.Module): ...@@ -773,33 +752,13 @@ class Zamba2Model(nn.Module):
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = inputs_embeds hidden_states = inputs_embeds
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
# Process through layers # Process through layers
original_hidden_states = torch.clone(hidden_states) original_hidden_states = torch.clone(hidden_states)
for layer_idx, layer in enumerate(self.layers): for layer_idx, layer in enumerate(self.layers):
layer_mamba_cache_params = None
if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
and mamba_cache_params):
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
layer_idx)
layer_outputs = layer( layer_outputs = layer(
hidden_states, hidden_states,
original_hidden_states=original_hidden_states, original_hidden_states=original_hidden_states,
positions=positions, positions=positions,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
hidden_states = layer_outputs hidden_states = layer_outputs
...@@ -870,13 +829,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -870,13 +829,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
...@@ -896,7 +853,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -896,7 +853,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
head_dim=hf_config.mamba_headdim, head_dim=hf_config.mamba_headdim,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
...@@ -945,9 +901,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -945,9 +901,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
# Tie weights with input embeddings if using same dimensions # Tie weights with input embeddings if using same dimensions
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
# Initialize logits processing and sampling # Initialize logits processing and sampling
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
...@@ -977,61 +930,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -977,61 +930,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
Returns: Returns:
Output hidden states Output hidden states
""" """
# Initialize Mamba cache if needed
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
# Forward pass through model # Forward pass through model
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
positions, positions,
mamba_cache_params,
inputs_embeds, inputs_embeds,
) )
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(
self, input_buffers: dict[str, torch.Tensor],
**kwargs: Any) -> dict[str, torch.Tensor]:
"""Copy inputs before CUDA graph capture.
Args:
input_buffers: Dictionary of input tensors
**kwargs: Additional arguments passed to cache manager
Returns:
Updated input buffers
"""
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(
self, batch_size: int) -> dict[str, torch.Tensor]:
"""Get inputs for sequence-length-agnostic graph capture.
Args:
batch_size: Size of batch to capture
Returns:
Dictionary of capture inputs
"""
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -12,6 +12,7 @@ from vllm.config import VllmConfig ...@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport, from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills) split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
...@@ -52,7 +53,6 @@ class GDNAttentionMetadata: ...@@ -52,7 +53,6 @@ class GDNAttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d # The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None
...@@ -134,6 +134,7 @@ class GDNAttentionMetadataBuilder( ...@@ -134,6 +134,7 @@ class GDNAttentionMetadataBuilder(
context_lens = m.num_computed_tokens_cpu context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device) context_lens_tensor = context_lens.to(query_start_loc.device)
seq_lens_tensor = m.seq_lens seq_lens_tensor = m.seq_lens
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (not self.use_spec_decode or num_draft_tokens is None if (not self.use_spec_decode or num_draft_tokens is None
or num_draft_tokens.sum().item() == 0): or num_draft_tokens.sum().item() == 0):
...@@ -210,6 +211,8 @@ class GDNAttentionMetadataBuilder( ...@@ -210,6 +211,8 @@ class GDNAttentionMetadataBuilder(
has_initial_state = context_lens_tensor > 0 has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks] has_initial_state = has_initial_state[~spec_sequence_masks]
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(non_spec_query_start_loc)
else: else:
has_initial_state = None has_initial_state = None
num_actual_tokens = num_prefill_tokens + num_decode_tokens + \ num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
...@@ -297,6 +300,9 @@ class GDNAttentionMetadataBuilder( ...@@ -297,6 +300,9 @@ class GDNAttentionMetadataBuilder(
spec_sequence_masks=spec_sequence_masks, spec_sequence_masks=spec_sequence_masks,
spec_token_masks=spec_token_masks, spec_token_masks=spec_token_masks,
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
) )
return attn_metadata return attn_metadata
......
...@@ -7,11 +7,12 @@ from typing import Optional ...@@ -7,11 +7,12 @@ from typing import Optional
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backends.mamba_attn import ( from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadataBuilder) BaseMambaAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills) split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
...@@ -131,7 +132,6 @@ class Mamba2AttentionMetadata: ...@@ -131,7 +132,6 @@ class Mamba2AttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d # The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None
...@@ -161,6 +161,9 @@ class Mamba2AttentionMetadataBuilder( ...@@ -161,6 +161,9 @@ class Mamba2AttentionMetadataBuilder(
has_initial_states_p = None has_initial_states_p = None
prep_initial_states = False prep_initial_states = False
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
...@@ -198,6 +201,9 @@ class Mamba2AttentionMetadataBuilder( ...@@ -198,6 +201,9 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p, self.chunk_size, query_start_loc_p, self.chunk_size,
num_prefill_tokens)) num_prefill_tokens))
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(query_start_loc_p)
elif num_decodes <= self.decode_cudagraph_max_bs: elif num_decodes <= self.decode_cudagraph_max_bs:
# Pad state tensor for CUDA graph # Pad state tensor for CUDA graph
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
...@@ -220,5 +226,8 @@ class Mamba2AttentionMetadataBuilder( ...@@ -220,5 +226,8 @@ class Mamba2AttentionMetadataBuilder(
chunk_indices_p=chunk_indices_p, chunk_indices_p=chunk_indices_p,
chunk_offsets_p=chunk_offsets_p, chunk_offsets_p=chunk_offsets_p,
state_indices_tensor=state_indices_tensor, state_indices_tensor=state_indices_tensor,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
) )
return attn_metadata return attn_metadata
...@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend ...@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills) split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
...@@ -33,7 +34,6 @@ class ShortConvAttentionMetadata: ...@@ -33,7 +34,6 @@ class ShortConvAttentionMetadata:
# For causal_conv1d # For causal_conv1d
nums_dict: Optional[dict] = None nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None token_chunk_offset_ptr: Optional[torch.Tensor] = None
...@@ -57,6 +57,9 @@ class ShortConvAttentionMetadataBuilder( ...@@ -57,6 +57,9 @@ class ShortConvAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills( split_decodes_and_prefills(
common_attn_metadata, common_attn_metadata,
...@@ -70,6 +73,12 @@ class ShortConvAttentionMetadataBuilder( ...@@ -70,6 +73,12 @@ class ShortConvAttentionMetadataBuilder(
has_initial_states = has_initial_states_cpu.to( has_initial_states = has_initial_states_cpu.to(
query_start_loc.device) query_start_loc.device)
query_start_loc_p = common_attn_metadata.query_start_loc[
-num_prefills - 1:] - num_decode_tokens
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(query_start_loc_p)
attn_metadata = ShortConvAttentionMetadata( attn_metadata = ShortConvAttentionMetadata(
num_prefills=num_prefills, num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
...@@ -78,5 +87,8 @@ class ShortConvAttentionMetadataBuilder( ...@@ -78,5 +87,8 @@ class ShortConvAttentionMetadataBuilder(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
has_initial_states=has_initial_states, has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor, state_indices_tensor=state_indices_tensor,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
) )
return attn_metadata return attn_metadata
...@@ -34,6 +34,8 @@ logger = init_logger(__name__) ...@@ -34,6 +34,8 @@ logger = init_logger(__name__)
KVCacheLayoutType = Literal["NHD", "HND"] KVCacheLayoutType = Literal["NHD", "HND"]
_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None _KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
PAD_SLOT_ID = -1
def is_valid_kv_cache_layout(value: str) -> bool: def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType) return value in get_args(KVCacheLayoutType)
...@@ -838,3 +840,52 @@ def create_fast_prefill_custom_backend( ...@@ -838,3 +840,52 @@ def create_fast_prefill_custom_backend(
builder_cls=FastPrefillAttentionBuilder) builder_cls=FastPrefillAttentionBuilder)
return attn_backend return attn_backend
def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
    """Precompute the lookup tables needed for causal_conv1d.

    Consecutive differences of ``query_start_loc_p`` give the length of each
    sequence. For every ``BLOCK_M`` value, each sequence is split into
    ``ceil(seqlen / BLOCK_M)`` chunks, and two flat tables with one entry per
    chunk are built: the index of the owning sequence, and the chunk's offset
    within that sequence.

    Args:
        query_start_loc_p: 1-D tensor of cumulative start offsets; entry
            ``i + 1`` minus entry ``i`` is the length of sequence ``i``.

    Returns:
        A tuple ``(nums_dict, batch_ptr, token_chunk_offset_ptr)``:
        - ``nums_dict`` maps each ``BLOCK_M`` to a dict with keys ``nums``,
          ``tot``, ``mlist``, ``mlist_len``, ``offsetlist``, ``batch_ptr``
          and ``token_chunk_offset_ptr``.
        - ``batch_ptr`` is an int32 buffer whose first ``mlist_len`` entries
          hold the sequence index of each chunk; the rest is ``PAD_SLOT_ID``.
        - ``token_chunk_offset_ptr`` is an int32 buffer of the same size
          holding each chunk's offset within its sequence, padded likewise.
    """
    # Allocate the output buffers next to the input instead of assuming a
    # CUDA device, so the helper also works on other platforms.
    device = query_start_loc_p.device
    seqlens = query_start_loc_p.diff().to('cpu')
    nums_dict = {}  # type: ignore
    batch_ptr = None
    token_chunk_offset_ptr = None
    for BLOCK_M in [8]:  # cover all BLOCK_M values
        # Ceiling division: number of BLOCK_M-sized chunks per sequence.
        nums = -(-seqlens // BLOCK_M)
        # Sequence index repeated once per chunk of that sequence.
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
        mlist_len = len(mlist)
        # Buffers are over-allocated: at least 2048 entries, and always twice
        # the number of chunks actually in use.
        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
        # Chunk offset within each sequence: 0, 1, ..., num - 1.
        offsets: list[int] = []
        for num in nums.tolist():
            offsets.extend(range(num))
        offsetlist = torch.tensor(offsets, dtype=torch.int32)

        if batch_ptr is None:
            batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                                   PAD_SLOT_ID,
                                   dtype=torch.int32,
                                   device=device)
            token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                                                PAD_SLOT_ID,
                                                dtype=torch.int32,
                                                device=device)
        elif batch_ptr.nelement() < MAX_NUM_PROGRAMS:
            # A later BLOCK_M needs more room: grow and reset the padding.
            batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
            token_chunk_offset_ptr.resize_(  # type: ignore
                MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)

        batch_ptr[0:mlist_len].copy_(mlist)
        token_chunk_offset_ptr[  # type: ignore
            0:mlist_len].copy_(offsetlist)

        nums_dict[BLOCK_M] = {
            'nums': nums,
            'tot': nums.sum().item(),
            'mlist': mlist,
            'mlist_len': mlist_len,
            'offsetlist': offsetlist,
            'batch_ptr': batch_ptr,
            'token_chunk_offset_ptr': token_chunk_offset_ptr,
        }

    return nums_dict, batch_ptr, token_chunk_offset_ptr
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment