Unverified Commit a903669e authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[V1] Remove V0 code paths for Hybrid models (#25400)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent 2c58742d
......@@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
SSM_MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
"yujiepan/mamba2-codestral-v0.1-tiny-random",
# mamba2-codestral in transformers is broken pending:
# https://github.com/huggingface/transformers/pull/40861
#"yujiepan/mamba2-codestral-v0.1-tiny-random",
]
HYBRID_MODELS = [
......@@ -31,18 +33,7 @@ HYBRID_MODELS = [
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
]
V1_SUPPORTED_MODELS = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
"pfnet/plamo-2-1b",
"yujiepan/mamba2-codestral-v0.1-tiny-random",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
"tiny-random/qwen3-next-moe",
]
FULL_CUDA_GRAPH_MODELS = [
......@@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
"Zyphra/Zamba2-1.2B-instruct",
]
V0_UNSUPPORTED_MODELS = [
"LiquidAI/LFM2-1.2B",
]
FP32_STATE_MODELS = [
"state-spaces/mamba-130m-hf",
"Zyphra/Zamba2-1.2B-instruct",
......@@ -88,19 +75,15 @@ def test_models(
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
if model in V1_SUPPORTED_MODELS:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
vllm_v1_outputs = None
if model in V1_SUPPORTED_MODELS:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm-v1",
name_1="vllm",
)
......@@ -299,14 +282,14 @@ def test_full_cuda_graph(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm-v1",
name_1="vllm",
)
......@@ -340,12 +323,12 @@ def test_fp32_cache_state(
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
**{cache_dtype_param: "float32"}) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm-v1",
name_1="vllm",
)
......@@ -312,13 +312,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
trust_remote_code=True,
v0_only=True,
max_model_len=10240),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
max_transformers_version="4.55.4",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
max_transformers_version="4.53",
......@@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"),
extras={"tiny-random": "tiny-random/qwen3-next-moe"}, # noqa: E501
min_transformers_version="4.56.3"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
trust_remote_code=True,
......@@ -644,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"),
min_transformers_version="4.56.3"),
}
_TRANSFORMERS_BACKEND_MODELS = {
......
......@@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase):
# Contains the KV cache (mamba state) for the layer
# in the shape specified by `self.get_state_shape`.
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
kv_cache: list[Iterable[torch.Tensor]]
kv_cache: tuple[torch.Tensor, ...]
@abstractmethod
def get_state_shape(self) -> Iterable[tuple[int, ...]]:
......
......@@ -15,7 +15,6 @@ import torch.nn.functional as F
from einops import rearrange
from torch import nn
from vllm import envs
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
......@@ -42,8 +41,6 @@ if TYPE_CHECKING:
import torch
import torch.distributed
from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
class MiniMaxText01RMSNormTP(CustomOp):
name = "MiniMaxText01RMSNormTP"
......@@ -225,7 +222,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
......@@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
break
if _prefill_idx >= len(state_indices_tensor):
break
# prefills are packed at end of batch in V1
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
offset = attn_metadata.num_decode_tokens
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
......@@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
hidden_decode = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
if envs.VLLM_USE_V1:
hidden.insert(0, hidden_decode)
else:
hidden.append(hidden_decode)
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
......@@ -304,13 +296,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
if not envs.VLLM_USE_V1:
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
num_prefills = getattr(attn_metadata, "num_prefills", 0)
slot_id = state_indices_tensor[num_prefills:]
else:
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
......@@ -320,11 +305,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
return hidden
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: MinimaxCacheParams) -> None:
if not envs.VLLM_USE_V1:
self._forward(hidden_states, output, positions, kv_caches)
else:
positions: torch.Tensor) -> None:
torch.ops.vllm.linear_attention(
hidden_states,
output,
......@@ -333,11 +314,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
)
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[MinimaxCacheParams]) -> None:
positions: torch.Tensor) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1 and attn_metadata is not None:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
......@@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact = torch.nn.functional.silu(qkv32)
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if envs.VLLM_USE_V1:
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor
num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata,
"num_decode_tokens", 0)
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens
+ prefill_idx +
1]
q_start = attn_metadata.query_start_loc[num_decode_tokens +
prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens +
prefill_idx + 1]
query_len = q_end - q_start
context_len = attn_metadata.seq_lens[
num_decode_tokens + prefill_idx] - query_len
if context_len == 0:
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx]
block_to_clear = state_indices_tensor[num_decode_tokens
+ prefill_idx]
kv_cache[block_to_clear, ...] = 0
else:
assert kv_caches is not None
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if attn_metadata is None:
......@@ -410,8 +384,7 @@ def linear_attention(
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states,
output=output,
positions=positions,
kv_caches=None)
positions=positions)
def linear_attention_fake(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.platforms import current_platform
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
@dataclass
class Mamba2Metadata:
prep_initial_states: bool
chunk_size: int
has_initial_states_p: torch.Tensor
seq_idx_p: torch.Tensor
chunk_indices_p: torch.Tensor
chunk_offsets_p: torch.Tensor
"""
With continuous batching layout of `x` in vLLM, to enable a Triton program
to handle a request in parallel, two supporting tensors are used
(batch_ptr, token_chunk_offset_ptr)
BLOCK_M = the # tokens to be handled by a Triton program
(can be customized for different hardware)
nums_dict:
tracks the data associated with a given value of BLOCK_M
BLOCK_M = #tokens handled by a Triton program
cu_seqlen: total tokens per batch
(used as flag to update other data at each new input)
batch_ptr: tracks batch-id handled by the Triton program
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
(Triton implementation of causal_conv1d handles parallelism in 3-axes
- feature-axis
- batch-axis
- sequence-axis)
"""
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
"""Returns the appropriate metadata classes for the current platform."""
if current_platform.is_rocm():
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.triton_attn import (
TritonAttentionMetadata)
return (AiterFlashAttentionMetadata, TritonAttentionMetadata,
PlaceholderAttentionMetadata)
if current_platform.is_cuda():
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata)
from vllm.v1.attention.backends.xformers import (
XFormersAttentionMetadata)
return (FlashAttentionMetadata, XFormersAttentionMetadata,
PlaceholderAttentionMetadata)
raise ValueError(
f"Unsupported platform for Mamba2: {current_platform.device_type}")
def prepare_mamba2_metadata(
chunk_size: int,
attn_metadata: AttentionMetadata,
) -> Mamba2Metadata:
# compute number of prefill and decode requests
# NOTE: in V0 we assume prefills are before decodes
num_prefills = attn_metadata.num_prefills
num_prefill_tokens = attn_metadata.num_prefill_tokens
seq_idx_p = None
chunk_indices_p, chunk_offsets_p = None, None
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states_p = None
prep_initial_states = False
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:
attn_metadata_instances = get_platform_metadata_classes()
if (isinstance(attn_metadata, attn_metadata_instances)
and attn_metadata.context_lens_tensor is not None):
# precompute flag to avoid device syncs later in mamba2 layer
# forwards
# prep is only needed for mamba2 ssd prefill processing
has_initial_states_p = (
attn_metadata.context_lens_tensor[:num_prefills] > 0)
prep_initial_states = torch.any(has_initial_states_p).item()
query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
seq_idx_p = torch.repeat_interleave(torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
query_start_loc_p.diff(),
output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)
# We compute metadata for chunked prefill once at the top level model
# forward and reuse them in mamba layers. If not needed, they will be
# ignored inside mamba kernels.
if prep_initial_states:
chunk_indices_p, chunk_offsets_p = \
_query_start_loc_to_chunk_indices_offsets(
query_start_loc_p, chunk_size, num_prefill_tokens)
return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
seq_idx_p=seq_idx_p,
chunk_indices_p=chunk_indices_p,
chunk_offsets_p=chunk_offsets_p)
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
mamba2_metadata: Union[Mamba2Metadata,
Mamba2AttentionMetadata,
GDNAttentionMetadata]):
"""
this is triggered upon handling a new input at the first layer
"""
dim, cu_seqlen = x.shape
mamba2_metadata.cu_seqlen = cu_seqlen
seqlens = np.diff(query_start_loc.to('cpu'))
nums_dict = {} # type: ignore
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]['nums'] = nums
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]['mlist'] = mlist
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = [] # type: ignore
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
if mamba2_metadata.batch_ptr is None:
# Update default value after class definition
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
mamba2_metadata.token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
else:
if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
PAD_SLOT_ID)
mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
mamba2_metadata.token_chunk_offset_ptr[ # type: ignore
0:mlist_len].copy_(offsetlist)
nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
mamba2_metadata.token_chunk_offset_ptr) # type: ignore
mamba2_metadata.nums_dict = nums_dict
return mamba2_metadata
......@@ -10,8 +10,6 @@ import torch
from torch import nn
from torch.nn.parameter import Parameter
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
......@@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
......@@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp):
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
self.kv_cache = (torch.tensor([]), torch.tensor([]))
self.model_config = model_config
self.cache_config = cache_config
......@@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp):
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
return discrete_time_step, B, C
def forward(self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params)
else:
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
torch.ops.vllm.mamba_mixer(
hidden_states,
output,
self.prefix,
)
def forward_native(self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
def forward_native(self, hidden_states: torch.Tensor,
output: torch.Tensor):
pass
def forward_cuda(self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
"""
Run the Mamba-1 SSM pipeline.
......@@ -234,7 +216,6 @@ class MambaMixer(MambaBase, CustomOp):
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
......@@ -247,18 +228,6 @@ class MambaMixer(MambaBase, CustomOp):
ssm_state = self_kv_cache[1]
has_initial_states = mamba1_metadata.has_initial_states
num_padded_decodes = mamba1_metadata.num_padded_decodes
else:
assert isinstance(attn_metadata, AttentionMetadata)
assert mamba_cache_params is not None
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
query_start_loc = attn_metadata.query_start_loc
context_lens_tensor = attn_metadata.context_lens_tensor
has_initial_states = None
if context_lens_tensor is not None:
has_initial_states = context_lens_tensor > 0
num_padded_decodes = attn_metadata.num_decode_tokens
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
......@@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if envs.VLLM_USE_V1 and attn_metadata is None:
if attn_metadata is None:
# V1 profile run
hidden_states_BC = hidden_states_BC.contiguous()
return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
......@@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp):
out=scan_outputs_d)
scan_outputs_d = scan_outputs_d.transpose(0, 1)
if envs.VLLM_USE_V1:
ssm_outputs.insert(0, scan_outputs_d)
else:
ssm_outputs.append(scan_outputs_d)
scan_outputs_combined = ssm_outputs[0] if len(
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
......@@ -441,9 +407,9 @@ def split_batch_to_prefill_and_decode(
num_decodes: int,
num_padded_decodes: int,
) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_padded_decodes
if envs.VLLM_USE_V1:
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC[..., :num_actual_tokens],
......@@ -462,19 +428,6 @@ def split_batch_to_prefill_and_decode(
num_padded_decodes if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[-num_prefills:] if (
has_initial_states is not None and num_prefills > 0) else None
else:
# In v0, prefill tokens come first, then decode tokens.
hidden_states_BC_p, hidden_states_BC_d = torch.split(
hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
gate_p, gate_d = torch.split(gate,
[num_prefill_tokens, num_decode_tokens],
dim=-1)
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor, [num_prefills, num_decodes], dim=0)
query_start_loc_p = (query_start_loc[:num_prefills +
1] if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[:num_prefills] if (
has_initial_states is not None and num_prefills > 0) else None
return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,
......@@ -495,9 +448,7 @@ def mamba_mixer(
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mamba_cache_params=None)
self.forward_cuda(hidden_states=hidden_states, output=output)
def mamba_mixer_fake(
......
......@@ -9,7 +9,6 @@ if TYPE_CHECKING:
import torch
from torch import nn
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
......@@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
......@@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
......@@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp):
self.use_rms_norm,
eps=rms_norm_eps)
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
# The tuple is (conv_state, ssm_state)
self.kv_cache = (torch.tensor([]), torch.tensor([]))
self.model_config = model_config
self.cache_config = cache_config
......@@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
pass
......@@ -478,14 +468,8 @@ class MambaMixer2(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
mamba2_metadata, mup_vector)
else:
torch.ops.vllm.mamba_mixer2(
hidden_states,
output,
......@@ -497,40 +481,30 @@ class MambaMixer2(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
forward_context = get_forward_context()
# mamba2_metadata contains metadata necessary for the mamba2 triton
# attn_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
mamba2_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
else:
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
# Common members between V1 metadata and V0 metadata
if mamba2_metadata is not None:
has_initial_states_p = mamba2_metadata.has_initial_states_p
prep_initial_states = mamba2_metadata.prep_initial_states
chunk_size = mamba2_metadata.chunk_size
seq_idx_p = mamba2_metadata.seq_idx_p
chunk_indices_p = mamba2_metadata.chunk_indices_p
chunk_offsets_p = mamba2_metadata.chunk_offsets_p
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
......@@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1,
)
if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
if attn_metadata is None:
# profile run
hidden_states_B_C = (hidden_states_B_C.transpose(
0, 1).clone().transpose(0, 1)).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(
......@@ -579,10 +553,8 @@ class MambaMixer2(MambaBase, CustomOp):
has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
......@@ -602,26 +574,6 @@ class MambaMixer2(MambaBase, CustomOp):
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
else:
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
......@@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp):
dtype=hidden_states.dtype,
device=hidden_states.device,
)
if envs.VLLM_USE_V1:
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out,
[num_decodes, num_prefill_tokens],
dim=0,
)
else:
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests
if has_prefill:
......@@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp):
# pointed to by "state_indices_tensor"
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(x, query_start_loc_p,
mamba2_metadata)
hidden_states_B_C_p = causal_conv1d_fn(
x,
conv_weights,
......@@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=mamba2_metadata,
metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
......@@ -806,8 +748,6 @@ def mamba_mixer2(
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mamba_cache_params=None,
mamba2_metadata=None,
mup_vector=mup_vector)
......
......@@ -100,7 +100,6 @@ class MambaStateShapeCalculator:
intermediate_size: int,
state_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int]]:
conv_state_shape = (divide(intermediate_size,
tp_world_size), conv_kernel - 1)
......@@ -108,10 +107,6 @@ class MambaStateShapeCalculator:
temporal_state_shape = (divide(intermediate_size,
tp_world_size), state_size)
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return conv_state_shape, temporal_state_shape
......@@ -126,7 +121,6 @@ class MambaStateShapeCalculator:
head_dim: int,
state_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
......@@ -137,8 +131,6 @@ class MambaStateShapeCalculator:
# contiguous along 'dim' axis
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
if not use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
......@@ -153,12 +145,9 @@ class MambaStateShapeCalculator:
tp_world_size: int,
intermediate_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim)
if not use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return (conv_state_shape, )
@classmethod
......@@ -183,7 +172,6 @@ class MambaStateShapeCalculator:
head_v_dim: int,
conv_kernel_size: int,
num_spec: int = 0,
use_v1: bool = True,
):
conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
conv_state_shape = (
......@@ -191,10 +179,6 @@ class MambaStateShapeCalculator:
conv_kernel_size - 1 + num_spec,
)
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
temporal_state_shape = (divide(num_v_heads,
......
......@@ -420,9 +420,7 @@ def causal_conv1d_fn(
x = x.to(conv_states.dtype)
out = torch.empty_like(x)
if metadata is not None:
cu_seqlen = metadata.cu_seqlen
nums_dict = metadata.nums_dict
#x = metadata.x
args = nums_dict
batch_ptr = metadata.batch_ptr
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
......@@ -926,7 +924,6 @@ def causal_conv1d_update(
query_start_loc: Optional[torch.Tensor] = None,
max_query_len: int = -1,
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
):
"""
......
......@@ -8,7 +8,6 @@ if TYPE_CHECKING:
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
......@@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp):
prefix=f"{prefix}.out_proj",
)
assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
self.kv_cache = [(torch.tensor([]), )]
self.kv_cache = (torch.tensor([]), )
self.model_config = model_config
self.cache_config = cache_config
......@@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
return
......@@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
torch.ops.vllm.short_conv(
hidden_states,
......@@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
forward_context = get_forward_context()
# ShortConvAttentionMetadata contains metadata necessary for the
......@@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp):
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
......@@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp):
if has_prefill:
Bx_p = (B_p * x_p).transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(Bx_p, query_start_loc_p,
conv_metadata)
Bx = causal_conv1d_fn(Bx_p,
conv_weights,
self.conv.bias,
......@@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=conv_metadata,
metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
......@@ -248,9 +235,7 @@ def short_conv(
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
conv_metadata=None)
self.forward_cuda(hidden_states=hidden_states, output=output)
def short_conv_fake(
......
......@@ -9,21 +9,17 @@ import torch
from torch import nn
from transformers import BambaConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
......@@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant)
......@@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
......@@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
self.mamba(hidden_states, output)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states)
......@@ -315,22 +306,10 @@ class BambaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
......@@ -343,23 +322,11 @@ class BambaModel(nn.Module):
residual = intermediate_tensors["residual"]
residual = None
num_attn = 0
for i, layer in enumerate(self.layers):
if isinstance(layer, BambaAttentionDecoderLayer):
num_attn += 1
layer_mamba_cache_params = None
if isinstance(layer,
BambaMixerDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
if not get_pp_group().is_last_rank:
......@@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......@@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
......@@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
class ConstantSizeCache(ABC):
"""
Abstract base class for managing constant size caches
like Mamba and Minimax.
"""
def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self.cache_indices_mapping: dict[str, dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
@abstractmethod
def cache(self) -> Any:
"""Return the underlying cache tensor(s)"""
pass
@abstractmethod
def _copy_cache(self, from_index: int, to_index: int):
"""Copy cache data from one index to another"""
pass
def current_run_tensors(self, **kwargs) -> tuple:
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
cache_tensors = self.cache
else:
# CUDA graph capturing runs
cache_tensors, state_indices_tensor = kwargs[
"seqlen_agnostic_capture_inputs"]
return (cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Cache during the CUDA graph replay
runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.cache, state_indices_tensor)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_cache(from_index=index_exists,
to_index=destination_index)
self.cache_indices_mapping[cur_rid][seq_id] = destination_index
return destination_index
else:
return self.cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_cache(
self, request_ids_to_seq_ids: dict[str, list[int]],
finished_requests_ids: list[str]) -> list[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: list[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.cache_indices_mapping:
for seq_id in self.cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.cache_indices_mapping[req_id][seq_id])
self.cache_indices_mapping.pop(req_id)
......@@ -8,21 +8,17 @@ import torch
from torch import nn
from transformers import FalconH1Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
......@@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
......@@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
output = torch.empty_like(hidden_states)
self.mamba(
hidden_states,
output,
mamba_cache_params,
mamba2_metadata=mamba2_metadata,
mup_vector=self.mup_vector,
)
return output, residual
......@@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
residual = hidden_states
......@@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module):
# Process input through the SSM branch.
# FalconH1SSMDecoderLayer expects hidden_states, attn_metadata,
# residual, mamba_cache_params, and sequence_idx.
# residual, and sequence_idx.
ssm_hidden, _ = self.mamba(
hidden_states=hidden_states * self.ssm_in_multiplier,
residual=residual,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
**kwargs,
)
# Sum the outputs from both branches.
......@@ -464,25 +450,10 @@ class FalconH1Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds * self.embedding_multiplier
......@@ -495,14 +466,9 @@ class FalconH1Model(nn.Module):
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
layer_mamba_cache_params = None
if mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
hidden_states = layer(
positions=positions,
hidden_states=hidden_states,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......@@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
prefix=maybe_prefix(prefix, "model"))
self.tie_word_embeddings = config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size
self.mamba_cache: Optional[MambaCacheManager] = None
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
if get_pp_group().is_last_rank:
......@@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
**kwargs,
):
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(
self.vllm_config,
self.config.num_hidden_layers,
*mamba_state_shape,
*mamba_state_dtype,
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(
input_ids,
positions,
mamba_cache_params,
intermediate_tensors,
inputs_embeds,
)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -9,19 +9,15 @@ import torch
from torch import nn
from transformers import GraniteMoeHybridConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
......@@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .granitemoe import GraniteMoeMoE
from .granitemoeshared import GraniteMoeSharedMLP
......@@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
self.mamba(hidden_states, output)
hidden_states = residual + output * self.residual_multiplier
residual = hidden_states
......@@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
......@@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module):
for i, layer in enumerate(self.layers):
if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer):
num_attn += 1
layer_mamba_cache_params = None
if isinstance(
layer,
GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)
hidden_states, residual = layer(
positions=positions,
hidden_states, residual = layer(positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata)
residual=residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......@@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
scale=1 /
self.config.logits_scaling)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -9,7 +9,6 @@ import torch
from torch import nn
from transformers import JambaConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
......@@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
......@@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
......@@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params)
self.mamba(hidden_states, output)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states)
......@@ -333,7 +328,6 @@ class JambaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
......@@ -348,24 +342,11 @@ class JambaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
kv_cache_index = 0
mamba_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache_index += 1
if isinstance(layer,
JambaMambaDecoderLayer) and mamba_cache_params:
current_state_layer = mamba_cache_index
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer)
mamba_cache_index += 1
hidden_states, residual = layer(
positions=positions,
hidden_states, residual = layer(positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params)
residual=residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
......@@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
......@@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
state_dtype = self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_layers, *state_shape,
*state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
......@@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_size=hf_config.mamba_expand * hidden_size,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=envs.VLLM_USE_V1,
)
def compute_logits(
......
......@@ -8,7 +8,6 @@ import torch
import torch.nn as nn
from transformers import Lfm2Config
from vllm import envs
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
......@@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
self.conv(
hidden_states,
output,
conv_metadata=None,
)
hidden_states, residual = self.ffn_norm(output, residual)
hidden_states = self.feed_forward(hidden_states)
......@@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int]]:
""" Calculate shapes for LFM2's convolutional cache.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......@@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
tp_world_size=parallel_config.tensor_parallel_size,
intermediate_size=hf_config.conv_dim,
conv_kernel=hf_config.conv_L_cache,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
......@@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
scheduler_config = vllm_config.scheduler_config
assert (not cache_config.enable_prefix_caching
), "Lfm2 currently does not support prefix caching"
assert envs.VLLM_USE_V1, (
"Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1")
super().__init__()
self.config = config
......
......@@ -8,7 +8,6 @@ import torch
from torch import nn
from transformers import MambaConfig
from vllm import envs
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
......@@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree, SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
......@@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
......@@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params)
self.mixer(hidden_states, output)
return output, residual
......@@ -134,7 +129,6 @@ class MambaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
......@@ -151,17 +145,9 @@ class MambaModel(nn.Module):
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
layer_cache_params = None
if mamba_cache_params is not None:
layer_cache_params = mamba_cache_params.at_layer_idx(
i - self.start_layer)
hidden_states, residual = layer(
positions=positions,
hidden_states, residual = layer(positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_cache_params)
residual=residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
......@@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
prefix=maybe_prefix(prefix, "lm_head"),
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
......@@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
state_dtype = self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_layers, *state_shape,
*state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
hidden_states = self.backbone(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
......@@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
tp_world_size=parallel_config.tensor_parallel_size,
intermediate_size=hf_config.intermediate_size,
state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel,
use_v1=envs.VLLM_USE_V1)
conv_kernel=hf_config.conv_kernel)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
......
......@@ -8,16 +8,11 @@ import torch
from torch import nn
from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
......@@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
......@@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
......@@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
self.mixer(hidden_states, output)
return output, residual
......@@ -137,7 +127,6 @@ class Mamba2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
......@@ -152,25 +141,10 @@ class Mamba2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
positions=positions,
hidden_states, residual = layer(positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer) if mamba_cache_params else None,
mamba2_metadata=mamba2_metadata)
residual=residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
......@@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
head_dim=hf_config.head_dim,
state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
......@@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
hidden_states = self.backbone(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MambaCacheParams:
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MambaCacheParams(self.conv_state[layer_idx],
self.ssm_state[layer_idx],
self.state_indices_tensor)
class MambaCacheManager(ConstantSizeCache):
def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
conv_state_shape: tuple[int, int],
temporal_state_shape: tuple[int, int],
conv_state_dtype: torch.dtype,
temporal_state_dtype: torch.dtype):
self.conv_state_dtype = conv_state_dtype
self.temporal_state_dtype = temporal_state_dtype
# Determine max batch size to set size of MambaCache
max_batch_size = vllm_config.scheduler_config.max_num_seqs
if not vllm_config.model_config.enforce_eager:
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
# Initialize parent class
super().__init__(max_batch_size)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
(conv_state_shape[1], conv_state_shape[0]),
dtype=self.conv_state_dtype,
device="cuda").transpose(-1, -2)
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=self.temporal_state_dtype,
device="cuda")
self._mamba_cache = (conv_state, temporal_state)
@property
def cache(self):
return self._mamba_cache
def _copy_cache(self, from_index: int, to_index: int):
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def current_run_tensors(self, **kwargs) -> MambaCacheParams:
"""
Return the tensors for the current run's conv and ssm state.
"""
cache_tensors, state_indices_tensor = super().current_run_tensors(
**kwargs)
return MambaCacheParams(cache_tensors[0], cache_tensors[1],
state_indices_tensor)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MinimaxCacheParams:
minimax_cache: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
self.state_indices_tensor)
class MinimaxCacheManager(ConstantSizeCache):
def __init__(self, dtype, cache_shape):
super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1]
self._minimax_cache = torch.empty(size=cache_shape,
dtype=dtype,
device="cuda")
@property
def cache(self):
return self._minimax_cache
def _copy_cache(self, from_index: int, to_index: int):
assert len(self.cache) > 0
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment