Unverified Commit a903669e authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[V1] Remove V0 code paths for Hybrid models (#25400)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent 2c58742d
......@@ -14,7 +14,6 @@ import torch.distributed
from torch import nn
from transformers import MiniMaxConfig
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
......@@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
......@@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
def forward(self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
kv_caches: Union[list[dict], Optional[torch.Tensor]],
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
is_warmup: bool = False,
......@@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_states=layernorm_output,
output=self_attention_output,
positions=positions,
kv_caches=kv_caches,
)
residual = residual * self.layernorm_attention_alpha
......@@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module):
self._dtype = _dummy.dtype
del _dummy
if not envs.VLLM_USE_V1:
self.minimax_cache = MinimaxCacheManager(
dtype=torch.float32, cache_shape=self.cache_shape)
norm_kwargs = {}
if hasattr(config, "rms_norm_eps"):
norm_kwargs["eps"] = config.rms_norm_eps
......@@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module):
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if not envs.VLLM_USE_V1:
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
else:
minimax_cache_params = None
if get_pp_group().is_first_rank:
if inputs_embeds is None:
......@@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
minimax_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer):
_caches = None
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
current_state_layer)
minimax_cache_index += 1
hidden_states, residual = layer(
hidden_states=hidden_states,
positions=positions,
kv_caches=_caches,
attn_metadata=attn_metadata,
residual=residual,
)
......@@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, ...], ...]:
"""Calculate shape for MiniMaxText01LinearAttention cache.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......
......@@ -23,21 +23,17 @@ from typing import Optional
import torch
from torch import nn
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
......@@ -49,14 +45,11 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsLoRA, SupportsPP,
SupportsQuant)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import (
AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig
from vllm.utils import LayerBlockType
class NemotronHMLP(nn.Module):
......@@ -181,8 +174,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
......@@ -192,7 +183,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
self.mixer(hidden_states, output)
return output, residual
......@@ -370,22 +361,10 @@ class NemotronHModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
......@@ -398,22 +377,11 @@ class NemotronHModel(nn.Module):
residual = intermediate_tensors["residual"]
residual = None
num_non_mamba_layers = 0
for i, layer in enumerate(self.layers):
layer_mamba_cache_params = None
if isinstance(layer,
NemotronHMambaDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_non_mamba_layers)
else:
num_non_mamba_layers += 1
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
if not get_pp_group().is_last_rank:
......@@ -508,13 +476,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......@@ -533,7 +499,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim=hf_config.mamba_head_dim,
state_size=hf_config.ssm_state_size,
conv_kernel=hf_config.conv_kernel,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -566,8 +531,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
......@@ -584,40 +547,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
......
This diff is collapsed.
......@@ -12,7 +12,6 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
......@@ -29,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata, update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
......@@ -47,15 +44,13 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import (
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, direct_register_custom_op
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
......@@ -194,17 +189,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self.chunk_size = self.config.mamba_chunk_size
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
assert self.chunk_size != -1, "chunk_size must be set for v1"
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The tuple is (conv_state, ssm_state)
self.kv_cache = (torch.tensor([]), torch.tensor([]))
assert self.chunk_size != -1, "chunk_size must be set for v1"
self.prefix = prefix
......@@ -227,8 +218,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
pass
......@@ -237,59 +226,43 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
mamba2_metadata)
else:
torch.ops.vllm.plamo2_mamba_mixer(
hidden_states,
output,
self.prefix,
)
torch.ops.vllm.plamo2_mamba_mixer(
hidden_states,
output,
self.prefix,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
forward_context = get_forward_context()
# mamba2_metadata contains metadata necessary for the mamba2 triton
# attn_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
mamba2_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
else:
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
# Common members between V1 metadata and V0 metadata
if mamba2_metadata is not None:
has_initial_states_p = mamba2_metadata.has_initial_states_p
prep_initial_states = mamba2_metadata.prep_initial_states
chunk_size = mamba2_metadata.chunk_size
seq_idx_p = mamba2_metadata.seq_idx_p
chunk_indices_p = mamba2_metadata.chunk_indices_p
chunk_offsets_p = mamba2_metadata.chunk_offsets_p
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)
......@@ -299,8 +272,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
if attn_metadata is None:
# profile run
hidden_states = (hidden_states.transpose(0, 1).clone().transpose(
0, 1)).contiguous()
output[:] = self.out_proj(hidden_states)
......@@ -316,42 +289,23 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if envs.VLLM_USE_V1:
hidden_states_d, hidden_states_p = torch.split(
hidden_states[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
gate_d, gate_p = torch.split(gate[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
else:
hidden_states_p, hidden_states_d = torch.split(
hidden_states,
[num_prefill_tokens, num_decodes],
dim=0,
)
gate_p, gate_d = torch.split(gate,
[num_prefill_tokens, num_decodes],
dim=0)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)
hidden_states_d, hidden_states_p = torch.split(
hidden_states[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
gate_d, gate_p = torch.split(gate[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
......@@ -363,18 +317,11 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
dtype=hidden_states.dtype,
device=hidden_states.device,
)
if envs.VLLM_USE_V1:
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out,
[num_decodes, num_prefill_tokens],
dim=0,
)
else:
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out,
[num_decodes, num_prefill_tokens],
dim=0,
)
# Process prefill requests
if has_prefill:
......@@ -383,9 +330,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
# pointed to by "state_indices_tensor"
x = hidden_states_p.transpose(
0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(x, query_start_loc_p,
mamba2_metadata)
hidden_states_p = causal_conv1d_fn(
x,
conv_weights,
......@@ -394,7 +338,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=mamba2_metadata,
metadata=attn_metadata,
query_start_loc=query_start_loc_p)
hidden_states_p = hidden_states_p.transpose(0, 1)
hidden_states_p = hidden_states_p[:num_prefill_tokens]
......@@ -470,7 +414,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
-1, self.num_heads // self.tp_size, self.head_dim)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# - ssm_state's slots will be selected
# using state_indices_tensor_d
# NOTE: final output is an in-place update of out tensor
......@@ -530,10 +474,7 @@ def plamo2_mamba_mixer(
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mamba_cache_params=None,
mamba2_metadata=None)
self.forward_cuda(hidden_states=hidden_states, output=output)
def plamo2_mamba_mixer_fake(
......@@ -731,8 +672,6 @@ class Plamo2DecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
......@@ -747,8 +686,6 @@ class Plamo2DecoderLayer(nn.Module):
output = torch.empty_like(hidden_states)
mixer_kwargs = {
"output": output,
"mamba_cache_params": mamba_cache_params,
"mamba2_metadata": mamba2_metadata,
}
else:
mixer_kwargs = {
......@@ -790,23 +727,12 @@ class Plamo2Decoder(torch.nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor:
mamba_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None
if layer.is_mamba and mamba_cache_params is not None:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
mamba_cache_index)
mamba_cache_index += 1
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
return hidden_states, residual
......@@ -844,7 +770,6 @@ class Plamo2Model(torch.nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
......@@ -859,23 +784,10 @@ class Plamo2Model(torch.nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
if not envs.VLLM_USE_V1:
attn_metadata: AttentionMetadata = get_forward_context(
).attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
hidden_states, residual = self.layers(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -925,9 +837,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
if self.config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.config.vocab_size)
self.make_empty_intermediate_tensors = (
......@@ -942,39 +851,11 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
mamba_state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
@classmethod
def get_mamba_state_dtype_from_config(
cls,
......@@ -991,12 +872,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
......@@ -1015,7 +894,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
head_dim=hf_config.hidden_size_per_head,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def compute_logits(
......
......@@ -11,7 +11,6 @@ from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from vllm import envs
from vllm.attention import Attention, AttentionBackend, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
......@@ -35,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
mamba_v2_sharded_weight_loader)
from vllm.model_executor.layers.mamba.mamba_utils import (
......@@ -51,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
......@@ -198,14 +195,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.gated_delta_net_state_shape(
self.tp_size,
self.num_k_heads,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
self.conv_kernel_size,
self.num_spec,
use_v1=True)
self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
self.head_v_dim, self.conv_kernel_size, self.num_spec)
def __init__(
self,
......@@ -394,7 +385,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
cache_params: Optional[MambaCacheParams] = None,
):
return torch.ops.vllm.gdn_attention(
hidden_states,
......@@ -416,7 +406,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc
......@@ -479,12 +468,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# 2.2: process the remaining part
if attn_metadata.num_prefills > 0:
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(mixed_qkv_non_spec_T,
non_spec_query_start_loc,
conv_metadata)
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
# pointed to by "state_indices_tensor"
mixed_qkv_non_spec = causal_conv1d_fn(
mixed_qkv_non_spec_T,
conv_weights,
......@@ -494,7 +479,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=conv_metadata,
metadata=attn_metadata,
).transpose(0, 1)
elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update(
......@@ -1075,7 +1060,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Qwen3Next currently does not support prefix caching"
assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
self.quant_config = vllm_config.quant_config
super().__init__()
......@@ -1195,14 +1179,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
num_spec = (vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0)
return MambaStateShapeCalculator.gated_delta_net_state_shape(
tp_size,
hf_config.linear_num_key_heads,
hf_config.linear_num_value_heads,
hf_config.linear_key_head_dim,
hf_config.linear_value_head_dim,
hf_config.linear_conv_kernel_dim,
num_spec,
use_v1=True)
tp_size, hf_config.linear_num_key_heads,
hf_config.linear_num_value_heads, hf_config.linear_key_head_dim,
hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim,
num_spec)
def compute_logits(
self,
......
......@@ -134,7 +134,6 @@ _TEXT_GENERATION_MODELS = {
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
......
......@@ -15,12 +15,10 @@ import torch
from torch import nn
from transformers import Zamba2Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -29,8 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
......@@ -39,8 +35,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid
......@@ -515,8 +509,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
transformer_hidden_states: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
original_hidden_states: Optional[torch.Tensor] = None,
......@@ -525,8 +517,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
Args:
hidden_states: Input tensor [batch_size, seq_len, hidden_size]
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
transformer_hidden_states: Optional output from transformer path
Added to input if provided (used in hybrid architecture)
positions: Optional position IDs (unused in Mamba)
......@@ -555,8 +545,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
self.mamba(
hidden_states,
output,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
# residual connection after mamba
......@@ -607,8 +595,6 @@ class Zamba2HybridLayer(nn.Module):
hidden_states: torch.Tensor,
original_hidden_states: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor:
"""Forward pass through the hybrid layer.
......@@ -623,8 +609,6 @@ class Zamba2HybridLayer(nn.Module):
original_hidden_states: Original input for transformer residual
connection
positions: Position IDs for positional embeddings
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
Returns:
Output tensor combining transformer and Mamba representations
......@@ -644,8 +628,6 @@ class Zamba2HybridLayer(nn.Module):
layer_outputs = self.mamba_decoder(
hidden_states,
transformer_hidden_states=transformer_hidden_states,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
return layer_outputs
......@@ -752,7 +734,6 @@ class Zamba2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Forward pass through the model.
......@@ -760,8 +741,6 @@ class Zamba2Model(nn.Module):
Args:
input_ids: Input token IDs
positions: Position IDs for embeddings
mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm)
inputs_embeds: Optional pre-computed input embeddings
Returns:
......@@ -773,33 +752,13 @@ class Zamba2Model(nn.Module):
inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = inputs_embeds
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
# Process through layers
original_hidden_states = torch.clone(hidden_states)
for layer_idx, layer in enumerate(self.layers):
layer_mamba_cache_params = None
if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
and mamba_cache_params):
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
layer_idx)
layer_outputs = layer(
hidden_states,
original_hidden_states=original_hidden_states,
positions=positions,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
hidden_states = layer_outputs
......@@ -870,13 +829,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......@@ -896,7 +853,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
head_dim=hf_config.mamba_headdim,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
......@@ -945,9 +901,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
# Tie weights with input embeddings if using same dimensions
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
# Initialize logits processing and sampling
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
......@@ -977,61 +930,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
Returns:
Output hidden states
"""
# Initialize Mamba cache if needed
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
# Forward pass through model
hidden_states = self.model(
input_ids,
positions,
mamba_cache_params,
inputs_embeds,
)
return hidden_states
def copy_inputs_before_cuda_graphs(
self, input_buffers: dict[str, torch.Tensor],
**kwargs: Any) -> dict[str, torch.Tensor]:
"""Copy inputs before CUDA graph capture.
Args:
input_buffers: Dictionary of input tensors
**kwargs: Additional arguments passed to cache manager
Returns:
Updated input buffers
"""
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(
self, batch_size: int) -> dict[str, torch.Tensor]:
"""Get inputs for sequence-length-agnostic graph capture.
Args:
batch_size: Size of batch to capture
Returns:
Dictionary of capture inputs
"""
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
......@@ -52,7 +53,6 @@ class GDNAttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
......@@ -134,6 +134,7 @@ class GDNAttentionMetadataBuilder(
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
seq_lens_tensor = m.seq_lens
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (not self.use_spec_decode or num_draft_tokens is None
or num_draft_tokens.sum().item() == 0):
......@@ -210,6 +211,8 @@ class GDNAttentionMetadataBuilder(
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(non_spec_query_start_loc)
else:
has_initial_state = None
num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
......@@ -297,6 +300,9 @@ class GDNAttentionMetadataBuilder(
spec_sequence_masks=spec_sequence_masks,
spec_token_masks=spec_token_masks,
num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
......
......@@ -7,11 +7,12 @@ from typing import Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
......@@ -131,7 +132,6 @@ class Mamba2AttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
......@@ -161,6 +161,9 @@ class Mamba2AttentionMetadataBuilder(
has_initial_states_p = None
prep_initial_states = False
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
......@@ -198,6 +201,9 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p, self.chunk_size,
num_prefill_tokens))
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(query_start_loc_p)
elif num_decodes <= self.decode_cudagraph_max_bs:
# Pad state tensor for CUDA graph
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
......@@ -220,5 +226,8 @@ class Mamba2AttentionMetadataBuilder(
chunk_indices_p=chunk_indices_p,
chunk_offsets_p=chunk_offsets_p,
state_indices_tensor=state_indices_tensor,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
......@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
......@@ -33,7 +34,6 @@ class ShortConvAttentionMetadata:
# For causal_conv1d
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
......@@ -57,6 +57,9 @@ class ShortConvAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
......@@ -70,6 +73,12 @@ class ShortConvAttentionMetadataBuilder(
has_initial_states = has_initial_states_cpu.to(
query_start_loc.device)
query_start_loc_p = common_attn_metadata.query_start_loc[
-num_prefills - 1:] - num_decode_tokens
nums_dict, batch_ptr, token_chunk_offset_ptr = \
compute_causal_conv1d_metadata(query_start_loc_p)
attn_metadata = ShortConvAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
......@@ -78,5 +87,8 @@ class ShortConvAttentionMetadataBuilder(
query_start_loc=query_start_loc,
has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
......@@ -34,6 +34,8 @@ logger = init_logger(__name__)
KVCacheLayoutType = Literal["NHD", "HND"]
_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
PAD_SLOT_ID = -1
def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType)
......@@ -838,3 +840,52 @@ def create_fast_prefill_custom_backend(
builder_cls=FastPrefillAttentionBuilder)
return attn_backend
def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
# Needed for causal_conv1d
seqlens = query_start_loc_p.diff().to('cpu')
nums_dict = {} # type: ignore
batch_ptr = None
token_chunk_offset_ptr = None
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]['nums'] = nums
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]['mlist'] = mlist
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = [] # type: ignore
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
if batch_ptr is None:
# Update default value after class definition
batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
else:
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
token_chunk_offset_ptr.resize_( # type: ignore
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
batch_ptr[0:mlist_len].copy_(mlist)
token_chunk_offset_ptr[ # type: ignore
0:mlist_len].copy_(offsetlist)
nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr
) # type: ignore
return nums_dict, batch_ptr, token_chunk_offset_ptr
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment