Unverified Commit a903669e authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[V1] Remove V0 code paths for Hybrid models (#25400)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent 2c58742d
......@@ -14,7 +14,6 @@ import torch.distributed
from torch import nn
from transformers import MiniMaxConfig
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
......@@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
......@@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
def forward(self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
kv_caches: Union[list[dict], Optional[torch.Tensor]],
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
is_warmup: bool = False,
......@@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_states=layernorm_output,
output=self_attention_output,
positions=positions,
kv_caches=kv_caches,
)
residual = residual * self.layernorm_attention_alpha
......@@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module):
self._dtype = _dummy.dtype
del _dummy
if not envs.VLLM_USE_V1:
self.minimax_cache = MinimaxCacheManager(
dtype=torch.float32, cache_shape=self.cache_shape)
norm_kwargs = {}
if hasattr(config, "rms_norm_eps"):
norm_kwargs["eps"] = config.rms_norm_eps
......@@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module):
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if not envs.VLLM_USE_V1:
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
else:
minimax_cache_params = None
if get_pp_group().is_first_rank:
if inputs_embeds is None:
......@@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
minimax_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer):
_caches = None
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
current_state_layer)
minimax_cache_index += 1
hidden_states, residual = layer(
hidden_states=hidden_states,
positions=positions,
kv_caches=_caches,
attn_metadata=attn_metadata,
residual=residual,
)
......@@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, ...], ...]:
"""Calculate shape for MiniMaxText01LinearAttention cache.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......
......@@ -23,21 +23,17 @@ from typing import Optional
import torch
from torch import nn
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
......@@ -49,14 +45,11 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsLoRA, SupportsPP,
SupportsQuant)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import (
AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig
from vllm.utils import LayerBlockType
class NemotronHMLP(nn.Module):
......@@ -181,8 +174,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
......@@ -192,7 +183,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
self.mixer(hidden_states, output)
return output, residual
......@@ -370,22 +361,10 @@ class NemotronHModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
......@@ -398,22 +377,11 @@ class NemotronHModel(nn.Module):
residual = intermediate_tensors["residual"]
residual = None
num_non_mamba_layers = 0
for i, layer in enumerate(self.layers):
layer_mamba_cache_params = None
if isinstance(layer,
NemotronHMambaDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_non_mamba_layers)
else:
num_non_mamba_layers += 1
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
if not get_pp_group().is_last_rank:
......@@ -508,13 +476,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......@@ -533,7 +499,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim=hf_config.mamba_head_dim,
state_size=hf_config.ssm_state_size,
conv_kernel=hf_config.conv_kernel,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -566,8 +531,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
......@@ -584,40 +547,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsV0Only)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from .utils import make_layers, maybe_prefix
logger = init_logger(__name__)
class SwiGLUActivation(nn.Module):
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return x1 * nn.functional.silu(x2)
class SambaYMLP(nn.Module):
"""Gated Linear Unit.
Reference:
Language Modeling with Gated Convolutional Networks.
https://arxiv.org/pdf/1612.08083v3.pdf.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.fc1 = nn.Linear(config.hidden_size,
2 * config.intermediate_size,
bias=False)
self.fc2 = nn.Linear(config.intermediate_size,
config.hidden_size,
bias=False)
self.activation_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
y = self.fc1(hidden_states)
gate, y = y.chunk(2, dim=-1)
y = y * self.activation_fn(gate)
return self.fc2(y)
def get_virtual_engine():
forward_context: ForwardContext = get_forward_context()
return forward_context.virtual_engine
class SambaYAttention(nn.Module):
def __init__(self,
config,
layer_idx: Optional[int] = None,
yoco_cross: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
super().__init__()
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing "
"a `layer_idx` is not recommended and will lead to errors "
"during the forward call if caching is used. Please make "
"sure to provide a `layer_idx` when creating this class.")
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.yoco_cross = yoco_cross
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError("hidden_size must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size} and "
f"`num_heads`: {self.num_heads}).")
op_size = self.num_heads * self.head_dim + 2 * (
self.num_key_value_heads * self.head_dim)
self.out_proj = nn.Linear(self.num_heads * self.head_dim,
self.hidden_size,
bias=True)
if yoco_cross:
self.Wqkv = nn.Linear(self.hidden_size,
self.num_heads * self.head_dim,
bias=True)
else:
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
# disable sliding window for the second half of the model
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
sliding_window = config.sliding_window if is_sliding else None
assert self.num_heads % 2 == 0, 'num_heads should be even'
assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'
self.lambda_init = self.lambda_init_fn(layer_idx)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.subln = nn.RMSNorm(2 * self.head_dim,
eps=1e-5,
elementwise_affine=True)
params = {
'differential_flash_attention_config': {
'lambda_init': self.lambda_init,
'lambda_q1': self.lambda_q1,
'lambda_k1': self.lambda_k1,
'lambda_q2': self.lambda_q2,
'lambda_k2': self.lambda_k2,
"subln": self.subln,
}
}
if yoco_cross:
kv_shared_layer_index = config.num_hidden_layers // 2 + 1
kv_sharing_target_layer_name = \
f"model.layers.{kv_shared_layer_index}.self_attn.attn"
else:
kv_sharing_target_layer_name = None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.head_dim**-0.5,
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
**params)
assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\
"DIFFERENTIAL_FLASH_ATTN required"
def lambda_init_fn(self, depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)
def forward(
self,
hidden_states: torch.Tensor,
):
if not self.yoco_cross: # need to generate kv-cache
qkv = self.Wqkv(hidden_states)
q, k, v = qkv.split([
self.hidden_size, self.num_key_value_heads * self.head_dim,
self.num_key_value_heads * self.head_dim
],
dim=-1)
attn_output = self.attn(q, k, v)
else: # reuse the kv cache, full attention
q = self.Wqkv(hidden_states)
attn_output = self.attn(q, None, None)
attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
return self.out_proj(attn_output)
class Phi4Mamba(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random", # difference
dt_scale=1.0, # difference
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True, # Fused kernel options
layer_idx=None,
device=None,
dtype=None,
yoco_cross=False,
yoco_kv=False,
):
factory_kwargs = {"params_dtype": dtype} # difference
super().__init__()
self.yoco_cross = yoco_cross
self.yoco_kv = yoco_kv
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model /
16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.swiGluActivation = SwiGLUActivation()
if self.yoco_cross:
self.in_proj = MergedColumnParallelLinear(self.d_model,
[self.d_inner],
bias=bias,
**factory_kwargs)
self.out_proj = RowParallelLinear(self.d_inner,
self.d_model,
bias=bias,
**factory_kwargs)
return
self.conv1d = ColumnParallelLinear(
input_size=d_conv,
output_size=self.d_inner,
bias=conv_bias,
params_dtype=dtype,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(
self.d_model,
[self.d_inner] * 2,
bias=bias,
params_dtype=dtype,
)
# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
self.d_inner,
self.dt_rank + self.d_state * 2,
bias=False,
params_dtype=dtype,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(
self.dt_rank,
self.d_inner,
bias=True,
skip_bias_add=True,
params_dtype=dtype,
)
# # D "skip" parameter
# self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
self.A = nn.Parameter(
torch.empty(
self.d_inner,
self.d_state,
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))
self.out_proj = RowParallelLinear(
self.d_inner,
self.d_model,
bias=bias,
input_is_parallel=True,
params_dtype=dtype,
)
self.activation = "silu"
def forward(self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
yoco_key_values=None) -> torch.Tensor:
if self.yoco_cross:
out = self.in_proj(hidden_states)[0]
out = self.swiGluActivation(yoco_key_values, out)
out = self.out_proj(out)
return out[0], yoco_key_values
# 1. Gated MLP's linear projection
# projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
projected_states = self.in_proj(
hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split(
ssm_parameters,
[self.dt_rank, self.d_state, self.d_state],
dim=-1,
)
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
mamba_cache_params.ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
# z,
None if self.yoco_kv else gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
# z
# gate.transpose(0, 1),
None if self.yoco_kv else gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
if self.yoco_kv:
# gate = gate.transpose(-1,-2).contiguous()
yoco_key_values = scan_outputs.transpose(-2, -1)
scan_outputs = self.swiGluActivation(scan_outputs, gate)
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states, yoco_key_values
class SambaYDecoderLayer(nn.Module):
def __init__(
self,
config,
layer_idx,
cache_config,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.mlp = SambaYMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.yoco_mb = False
self.yoco_cross = False
if layer_idx >= config.num_hidden_layers // 2:
self.yoco_mb = True
self.yoco_cross = (layer_idx
>= (config.num_hidden_layers // 2 + 2))
self.use_mamba = config.mb_per_layer > 0 and \
layer_idx % config.mb_per_layer == 0
if self.use_mamba:
factory_kwargs = {"dtype": None}
self.attn = Phi4Mamba(config.hidden_size,
layer_idx=layer_idx,
yoco_cross=self.yoco_cross,
yoco_kv=self.yoco_mb,
**factory_kwargs)
else:
self.attn = SambaYAttention(config,
layer_idx=layer_idx,
yoco_cross=self.yoco_cross,
cache_config=cache_config,
prefix=f"{prefix}.self_attn")
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
ssm_output: Optional[torch.LongTensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if self.use_mamba:
assert mamba_cache_params is not None
else:
assert mamba_cache_params is None
residual = hidden_states
hidden_states = self.input_layernorm(
hidden_states.to(dtype=self.input_layernorm.weight.dtype))
if self.use_mamba:
attn_outputs, ssm_output = self.attn(hidden_states,
attn_metadata,
mamba_cache_params,
yoco_key_values=ssm_output)
residual = residual.to(torch.float32)
else:
attn_outputs = self.attn(hidden_states, )
hidden_states = residual + attn_outputs
residual = hidden_states
hidden_states = self.post_attention_layernorm(
hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, ssm_output
class SambaYModel(nn.Module):
def __init__(self,
config,
cache_config=None,
quant_config=None,
lora_config=None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
# Pipeline parallel is not supported since the second half of
# the layers share the kv cache.
if get_pp_group().world_size != 1:
raise ValueError("Pipeline Parallel not supported")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: SambaYDecoderLayer(config,
int(prefix.split('.')[-1]),
cache_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
mamba_state_idx = 0
ssm_output = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if i == self.config.num_hidden_layers // 2 + 2:
# profile run
kv_cache_idx = self.config.num_hidden_layers // 2 + 1
cache_layer = self.layers[kv_cache_idx]
kv_cache = cache_layer.attn.attn.kv_cache
if kv_cache[0].numel() == 0:
break
# Starting from this layer, we do not need to calculate
# the kv cache since we reuse the kv cache from last layer.
# If in prefill phase, we can <s>prune></s> truncate
# the hidden state to save computation cost.
if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1:
selected_token_indices = torch.cumsum(
attn_metadata.seq_lens_tensor, dim=0) - 1
hidden_states = hidden_states.index_select(
0, selected_token_indices)
ssm_output = ssm_output.index_select(
0, selected_token_indices)
if layer.use_mamba:
if i < self.config.num_hidden_layers // 2 or \
not layer.yoco_cross:
mamba_cache = mamba_cache_params.at_layer_idx(
mamba_state_idx)
mamba_state_idx += 1
else:
mamba_cache = mamba_cache_params.at_layer_idx(
mamba_state_idx - 1)
hidden_states, ssm_output = layer(hidden_states,
positions,
attn_metadata,
mamba_cache,
ssm_output=ssm_output)
else:
hidden_states, ssm_output = layer(
hidden_states,
positions,
attn_metadata,
None, # mamba_cache_params
ssm_output=ssm_output)
hidden_states = self.final_layernorm(
hidden_states.to(dtype=self.final_layernorm.weight.dtype))
return hidden_states
class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
quant_config = vllm_config.quant_config
scheduler_config = vllm_config.scheduler_config
self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config
# Prefix caching and chunked prefill is not supported for this model.
assert not cache_config.enable_prefix_caching, \
"Phi4flash currently does not support prefix caching"
assert not scheduler_config.chunked_prefill_enabled, \
"Phi4Flash currently does not support prefix caching"
super().__init__()
self.config = config
self.model_config = vllm_config.model_config
self.scheduler_config = scheduler_config
self.model = SambaYModel(config,
cache_config=cache_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.embedding_bias = None
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logits_as_input=False)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers \
// 2 // self.config.mb_per_layer + 1
self.mamba_cache = MambaCacheManager(
self.vllm_config,
num_mamba_layers,
*self._get_mamba_cache_shape(),
self.lm_head.weight.dtype,
self.lm_head.weight.dtype,
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
attn_metadata = get_forward_context().attn_metadata
# input_ids and hidden_states isn't a one-to-one mapping in prefill
# stage due to YOCO optimization.
hidden_states = self.model(input_ids, positions, attn_metadata,
mamba_cache_params, intermediate_tensors,
inputs_embeds)
return hidden_states
def _get_mamba_cache_shape(
self
) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
mamba_expand = self.config.mamba_expand # 2
mamba_d_conv = self.config.mamba_d_conv # 4
mamba_d_state = self.config.mamba_d_state # 16
conv_state_shape = (
mamba_expand * hidden_size // world_size,
mamba_d_conv - 1,
)
temporal_state_shape = (
mamba_expand * hidden_size // world_size,
mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
processed_logits = self.logits_processor(
self.lm_head,
hidden_states,
self.embedding_bias,
)
return processed_logits
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
):
weights = {name: weight for name, weight in weights}
adjusted_weights = {}
for name, weight in weights.items():
if "A_log" in name:
name = name.replace("A_log", "A")
weight = -torch.exp(weight.float())
if "inner_cross_attn." in name:
name = name.replace("inner_cross_attn.", "")
adjusted_weights[name] = weight
adjusted_weights["lm_head.weight"] = weights[
"model.embed_tokens.weight"]
loaded_params: set[str] = set()
for name, param in self.named_parameters():
weight = adjusted_weights.get(name)
if weight is not None and weight.shape != param.shape:
logger.warning("Shape mismatch: %s %s %s", name, weight.shape,
param.shape)
loaded_params.add(name)
missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
strict=False)
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
return loaded_params
......@@ -12,7 +12,6 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
......@@ -29,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata, update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
......@@ -47,15 +44,13 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import (
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, direct_register_custom_op
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
......@@ -194,17 +189,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self.chunk_size = self.config.mamba_chunk_size
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
assert self.chunk_size != -1, "chunk_size must be set for v1"
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The tuple is (conv_state, ssm_state)
self.kv_cache = (torch.tensor([]), torch.tensor([]))
assert self.chunk_size != -1, "chunk_size must be set for v1"
self.prefix = prefix
......@@ -227,8 +218,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
pass
......@@ -237,59 +226,43 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
mamba2_metadata)
else:
torch.ops.vllm.plamo2_mamba_mixer(
hidden_states,
output,
self.prefix,
)
torch.ops.vllm.plamo2_mamba_mixer(
hidden_states,
output,
self.prefix,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
forward_context = get_forward_context()
# mamba2_metadata contains metadata necessary for the mamba2 triton
# attn_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
mamba2_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
else:
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
# Common members between V1 metadata and V0 metadata
if mamba2_metadata is not None:
has_initial_states_p = mamba2_metadata.has_initial_states_p
prep_initial_states = mamba2_metadata.prep_initial_states
chunk_size = mamba2_metadata.chunk_size
seq_idx_p = mamba2_metadata.seq_idx_p
chunk_indices_p = mamba2_metadata.chunk_indices_p
chunk_offsets_p = mamba2_metadata.chunk_offsets_p
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)
......@@ -299,8 +272,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
if attn_metadata is None:
# profile run
hidden_states = (hidden_states.transpose(0, 1).clone().transpose(
0, 1)).contiguous()
output[:] = self.out_proj(hidden_states)
......@@ -316,42 +289,23 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if envs.VLLM_USE_V1:
hidden_states_d, hidden_states_p = torch.split(
hidden_states[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
gate_d, gate_p = torch.split(gate[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
else:
hidden_states_p, hidden_states_d = torch.split(
hidden_states,
[num_prefill_tokens, num_decodes],
dim=0,
)
gate_p, gate_d = torch.split(gate,
[num_prefill_tokens, num_decodes],
dim=0)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)
hidden_states_d, hidden_states_p = torch.split(
hidden_states[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
gate_d, gate_p = torch.split(gate[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
......@@ -363,18 +317,11 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
dtype=hidden_states.dtype,
device=hidden_states.device,
)
if envs.VLLM_USE_V1:
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out,
[num_decodes, num_prefill_tokens],
dim=0,
)
else:
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out,
[num_decodes, num_prefill_tokens],
dim=0,
)
# Process prefill requests
if has_prefill:
......@@ -383,9 +330,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
# pointed to by "state_indices_tensor"
x = hidden_states_p.transpose(
0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(x, query_start_loc_p,
mamba2_metadata)
hidden_states_p = causal_conv1d_fn(
x,
conv_weights,
......@@ -394,7 +338,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=mamba2_metadata,
metadata=attn_metadata,
query_start_loc=query_start_loc_p)
hidden_states_p = hidden_states_p.transpose(0, 1)
hidden_states_p = hidden_states_p[:num_prefill_tokens]
......@@ -470,7 +414,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
-1, self.num_heads // self.tp_size, self.head_dim)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# - ssm_state's slots will be selected
# using state_indices_tensor_d
# NOTE: final output is an in-place update of out tensor
......@@ -530,10 +474,7 @@ def plamo2_mamba_mixer(
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mamba_cache_params=None,
mamba2_metadata=None)
self.forward_cuda(hidden_states=hidden_states, output=output)
def plamo2_mamba_mixer_fake(
......@@ -731,8 +672,6 @@ class Plamo2DecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
......@@ -747,8 +686,6 @@ class Plamo2DecoderLayer(nn.Module):
output = torch.empty_like(hidden_states)
mixer_kwargs = {
"output": output,
"mamba_cache_params": mamba_cache_params,
"mamba2_metadata": mamba2_metadata,
}
else:
mixer_kwargs = {
......@@ -790,23 +727,12 @@ class Plamo2Decoder(torch.nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor:
mamba_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None
if layer.is_mamba and mamba_cache_params is not None:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
mamba_cache_index)
mamba_cache_index += 1
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
return hidden_states, residual
......@@ -844,7 +770,6 @@ class Plamo2Model(torch.nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
......@@ -859,23 +784,10 @@ class Plamo2Model(torch.nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
if not envs.VLLM_USE_V1:
attn_metadata: AttentionMetadata = get_forward_context(
).attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
hidden_states, residual = self.layers(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -925,9 +837,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
if self.config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.config.vocab_size)
self.make_empty_intermediate_tensors = (
......@@ -942,39 +851,11 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
mamba_state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
@classmethod
def get_mamba_state_dtype_from_config(
cls,
......@@ -991,12 +872,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
......@@ -1015,7 +894,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
head_dim=hf_config.hidden_size_per_head,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def compute_logits(
......
......@@ -11,7 +11,6 @@ from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from vllm import envs
from vllm.attention import Attention, AttentionBackend, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
......@@ -35,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
mamba_v2_sharded_weight_loader)
from vllm.model_executor.layers.mamba.mamba_utils import (
......@@ -51,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
......@@ -198,14 +195,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.gated_delta_net_state_shape(
self.tp_size,
self.num_k_heads,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
self.conv_kernel_size,
self.num_spec,
use_v1=True)
self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
self.head_v_dim, self.conv_kernel_size, self.num_spec)
def __init__(
self,
......@@ -394,7 +385,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
cache_params: Optional[MambaCacheParams] = None,
):
return torch.ops.vllm.gdn_attention(
hidden_states,
......@@ -416,7 +406,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc
......@@ -479,12 +468,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# 2.2: process the remaining part
if attn_metadata.num_prefills > 0:
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(mixed_qkv_non_spec_T,
non_spec_query_start_loc,
conv_metadata)
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
# pointed to by "state_indices_tensor"
mixed_qkv_non_spec = causal_conv1d_fn(
mixed_qkv_non_spec_T,
conv_weights,
......@@ -494,7 +479,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=conv_metadata,
metadata=attn_metadata,
).transpose(0, 1)
elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update(
......@@ -1075,7 +1060,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Qwen3Next currently does not support prefix caching"
assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
self.quant_config = vllm_config.quant_config
super().__init__()
......@@ -1195,14 +1179,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
num_spec = (vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0)
return MambaStateShapeCalculator.gated_delta_net_state_shape(
tp_size,
hf_config.linear_num_key_heads,
hf_config.linear_num_value_heads,
hf_config.linear_key_head_dim,
hf_config.linear_value_head_dim,
hf_config.linear_conv_kernel_dim,
num_spec,
use_v1=True)
tp_size, hf_config.linear_num_key_heads,
hf_config.linear_num_value_heads, hf_config.linear_key_head_dim,
hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim,
num_spec)
def compute_logits(
self,
......
......@@ -134,7 +134,6 @@ _TEXT_GENERATION_MODELS = {
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
......
......@@ -15,12 +15,10 @@ import torch
from torch import nn
from transformers import Zamba2Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -29,8 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
......@@ -39,8 +35,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid
......@@ -515,8 +509,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
transformer_hidden_states: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
original_hidden_states: Optional[torch.Tensor] = None,
......@@ -525,8 +517,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
Args:
hidden_states: Input tensor [batch_size, seq_len, hidden_size]
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
transformer_hidden_states: Optional output from transformer path
Added to input if provided (used in hybrid architecture)
positions: Optional position IDs (unused in Mamba)
......@@ -555,8 +545,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
self.mamba(
hidden_states,
output,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
# residual connection after mamba
......@@ -607,8 +595,6 @@ class Zamba2HybridLayer(nn.Module):
hidden_states: torch.Tensor,
original_hidden_states: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor:
"""Forward pass through the hybrid layer.
......@@ -623,8 +609,6 @@ class Zamba2HybridLayer(nn.Module):
original_hidden_states: Original input for transformer residual
connection
positions: Position IDs for positional embeddings
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
Returns:
Output tensor combining transformer and Mamba representations
......@@ -644,8 +628,6 @@ class Zamba2HybridLayer(nn.Module):
layer_outputs = self.mamba_decoder(
hidden_states,
transformer_hidden_states=transformer_hidden_states,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
return layer_outputs
......@@ -752,7 +734,6 @@ class Zamba2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Forward pass through the model.
......@@ -760,8 +741,6 @@ class Zamba2Model(nn.Module):
Args:
input_ids: Input token IDs
positions: Position IDs for embeddings
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
inputs_embeds: Optional pre-computed input embeddings
Returns:
......@@ -773,33 +752,13 @@ class Zamba2Model(nn.Module):
inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = inputs_embeds
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
# Process through layers
original_hidden_states = torch.clone(hidden_states)
for layer_idx, layer in enumerate(self.layers):
layer_mamba_cache_params = None
if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
and mamba_cache_params):
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
layer_idx)
layer_outputs = layer(
hidden_states,
original_hidden_states=original_hidden_states,
positions=positions,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
hidden_states = layer_outputs
......@@ -870,13 +829,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......@@ -896,7 +853,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
head_dim=hf_config.mamba_headdim,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
......@@ -945,9 +901,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
# Tie weights with input embeddings if using same dimensions
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
# Initialize logits processing and sampling
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
......@@ -977,61 +930,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
Returns:
Output hidden states
"""
# Initialize Mamba cache if needed
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
# Forward pass through model
hidden_states = self.model(
input_ids,
positions,
mamba_cache_params,
inputs_embeds,
)
return hidden_states
def copy_inputs_before_cuda_graphs(
self, input_buffers: dict[str, torch.Tensor],
**kwargs: Any) -> dict[str, torch.Tensor]:
"""Copy inputs before CUDA graph capture.
Args:
input_buffers: Dictionary of input tensors
**kwargs: Additional arguments passed to cache manager
Returns:
Updated input buffers
"""
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(
self, batch_size: int) -> dict[str, torch.Tensor]:
"""Get inputs for sequence-length-agnostic graph capture.
Args:
batch_size: Size of batch to capture
Returns:
Dictionary of capture inputs
"""
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
......@@ -52,7 +53,6 @@ class GDNAttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
......@@ -134,6 +134,7 @@ class GDNAttentionMetadataBuilder(
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
seq_lens_tensor = m.seq_lens
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (not self.use_spec_decode or num_draft_tokens is None
or num_draft_tokens.sum().item() == 0):
......@@ -210,6 +211,8 @@ class GDNAttentionMetadataBuilder(
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(non_spec_query_start_loc)
else:
has_initial_state = None
num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
......@@ -297,6 +300,9 @@ class GDNAttentionMetadataBuilder(
spec_sequence_masks=spec_sequence_masks,
spec_token_masks=spec_token_masks,
num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
......
......@@ -7,11 +7,12 @@ from typing import Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
......@@ -131,7 +132,6 @@ class Mamba2AttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
......@@ -161,6 +161,9 @@ class Mamba2AttentionMetadataBuilder(
has_initial_states_p = None
prep_initial_states = False
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
......@@ -198,6 +201,9 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p, self.chunk_size,
num_prefill_tokens))
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(query_start_loc_p)
elif num_decodes <= self.decode_cudagraph_max_bs:
# Pad state tensor for CUDA graph
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
......@@ -220,5 +226,8 @@ class Mamba2AttentionMetadataBuilder(
chunk_indices_p=chunk_indices_p,
chunk_offsets_p=chunk_offsets_p,
state_indices_tensor=state_indices_tensor,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
......@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
......@@ -33,7 +34,6 @@ class ShortConvAttentionMetadata:
# For causal_conv1d
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
......@@ -57,6 +57,9 @@ class ShortConvAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
......@@ -70,6 +73,12 @@ class ShortConvAttentionMetadataBuilder(
has_initial_states = has_initial_states_cpu.to(
query_start_loc.device)
query_start_loc_p = common_attn_metadata.query_start_loc[
-num_prefills - 1:] - num_decode_tokens
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(query_start_loc_p)
attn_metadata = ShortConvAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
......@@ -78,5 +87,8 @@ class ShortConvAttentionMetadataBuilder(
query_start_loc=query_start_loc,
has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
......@@ -34,6 +34,8 @@ logger = init_logger(__name__)
KVCacheLayoutType = Literal["NHD", "HND"]
_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
PAD_SLOT_ID = -1
def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType)
......@@ -838,3 +840,52 @@ def create_fast_prefill_custom_backend(
builder_cls=FastPrefillAttentionBuilder)
return attn_backend
def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
# Needed for causal_conv1d
seqlens = query_start_loc_p.diff().to('cpu')
nums_dict = {} # type: ignore
batch_ptr = None
token_chunk_offset_ptr = None
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]['nums'] = nums
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]['mlist'] = mlist
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = [] # type: ignore
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
if batch_ptr is None:
# Update default value after class definition
batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
else:
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
token_chunk_offset_ptr.resize_( # type: ignore
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
batch_ptr[0:mlist_len].copy_(mlist)
token_chunk_offset_ptr[ # type: ignore
0:mlist_len].copy_(offsetlist)
nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr
) # type: ignore
return nums_dict, batch_ptr, token_chunk_offset_ptr
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment