Unverified Commit a903669e authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[V1] Remove V0 code paths for Hybrid models (#25400)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent 2c58742d
...@@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model ...@@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
SSM_MODELS = [ SSM_MODELS = [
"state-spaces/mamba-130m-hf", "state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev", "tiiuae/falcon-mamba-tiny-dev",
"yujiepan/mamba2-codestral-v0.1-tiny-random", # mamba2-codestral in transformers is broken pending:
# https://github.com/huggingface/transformers/pull/40861
#"yujiepan/mamba2-codestral-v0.1-tiny-random",
] ]
HYBRID_MODELS = [ HYBRID_MODELS = [
...@@ -31,18 +33,7 @@ HYBRID_MODELS = [ ...@@ -31,18 +33,7 @@ HYBRID_MODELS = [
"ibm-granite/granite-4.0-tiny-preview", "ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base", "tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B", "LiquidAI/LFM2-1.2B",
] "tiny-random/qwen3-next-moe",
V1_SUPPORTED_MODELS = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
"pfnet/plamo-2-1b",
"yujiepan/mamba2-codestral-v0.1-tiny-random",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
] ]
FULL_CUDA_GRAPH_MODELS = [ FULL_CUDA_GRAPH_MODELS = [
...@@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [ ...@@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
"Zyphra/Zamba2-1.2B-instruct", "Zyphra/Zamba2-1.2B-instruct",
] ]
V0_UNSUPPORTED_MODELS = [
"LiquidAI/LFM2-1.2B",
]
FP32_STATE_MODELS = [ FP32_STATE_MODELS = [
"state-spaces/mamba-130m-hf", "state-spaces/mamba-130m-hf",
"Zyphra/Zamba2-1.2B-instruct", "Zyphra/Zamba2-1.2B-instruct",
...@@ -88,19 +75,15 @@ def test_models( ...@@ -88,19 +75,15 @@ def test_models(
hf_outputs = hf_model.generate_greedy_logprobs_limit( hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
if model in V1_SUPPORTED_MODELS:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
else:
vllm_v1_outputs = None
if model in V1_SUPPORTED_MODELS:
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs, outputs_1_lst=vllm_outputs,
name_0="hf", name_0="hf",
name_1="vllm-v1", name_1="vllm",
) )
...@@ -299,14 +282,14 @@ def test_full_cuda_graph( ...@@ -299,14 +282,14 @@ def test_full_cuda_graph(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs, outputs_1_lst=vllm_outputs,
name_0="hf", name_0="hf",
name_1="vllm-v1", name_1="vllm",
) )
...@@ -340,12 +323,12 @@ def test_fp32_cache_state( ...@@ -340,12 +323,12 @@ def test_fp32_cache_state(
with vllm_runner(model, with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS, max_num_seqs=MAX_NUM_SEQS,
**{cache_dtype_param: "float32"}) as vllm_model: **{cache_dtype_param: "float32"}) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v1_outputs, outputs_1_lst=vllm_outputs,
name_0="hf", name_0="hf",
name_1="vllm-v1", name_1="vllm",
) )
...@@ -312,13 +312,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -312,13 +312,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
trust_remote_code=True,
v0_only=True,
max_model_len=10240),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True), trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
max_transformers_version="4.55.4",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
max_transformers_version="4.53", max_transformers_version="4.53",
...@@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"), extras={"tiny-random": "tiny-random/qwen3-next-moe"}, # noqa: E501
min_transformers_version="4.56.3"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
trust_remote_code=True, trust_remote_code=True,
...@@ -644,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -644,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL"), speculative_model="XiaomiMiMo/MiMo-7B-RL"),
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"), min_transformers_version="4.56.3"),
} }
_TRANSFORMERS_BACKEND_MODELS = { _TRANSFORMERS_BACKEND_MODELS = {
......
...@@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase): ...@@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase):
# Contains the KV cache (mamba state) for the layer # Contains the KV cache (mamba state) for the layer
# in the shape specified by `self.get_state_shape`. # in the shape specified by `self.get_state_shape`.
# The outer list is for v0 PP virtual engine. Though this code path kv_cache: tuple[torch.Tensor, ...]
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
kv_cache: list[Iterable[torch.Tensor]]
@abstractmethod @abstractmethod
def get_state_shape(self) -> Iterable[tuple[int, ...]]: def get_state_shape(self) -> Iterable[tuple[int, ...]]:
......
...@@ -15,7 +15,6 @@ import torch.nn.functional as F ...@@ -15,7 +15,6 @@ import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from vllm import envs
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
...@@ -42,8 +41,6 @@ if TYPE_CHECKING: ...@@ -42,8 +41,6 @@ if TYPE_CHECKING:
import torch import torch
import torch.distributed import torch.distributed
from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
class MiniMaxText01RMSNormTP(CustomOp): class MiniMaxText01RMSNormTP(CustomOp):
name = "MiniMaxText01RMSNormTP" name = "MiniMaxText01RMSNormTP"
...@@ -225,7 +222,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): ...@@ -225,7 +222,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self.tp_heads:(self.tp_rank + 1) * self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous() self.tp_heads].contiguous()
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
...@@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): ...@@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
break break
if _prefill_idx >= len(state_indices_tensor): if _prefill_idx >= len(state_indices_tensor):
break break
# prefills are packed at end of batch in V1 offset = attn_metadata.num_decode_tokens
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
_start = attn_metadata.query_start_loc[offset + _prefill_idx] _start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx] slot_id = state_indices_tensor[offset + _prefill_idx]
...@@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): ...@@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
hidden_decode = self._decode_infer(q, k, v, kv_cache, hidden_decode = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor, state_indices_tensor,
attn_metadata) attn_metadata)
if envs.VLLM_USE_V1:
hidden.insert(0, hidden_decode) hidden.insert(0, hidden_decode)
else:
hidden.append(hidden_decode)
if not hidden: if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
...@@ -304,13 +296,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): ...@@ -304,13 +296,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata): attn_metadata):
if not envs.VLLM_USE_V1:
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
num_prefills = getattr(attn_metadata, "num_prefills", 0)
slot_id = state_indices_tensor[num_prefills:]
else:
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
...@@ -320,11 +305,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): ...@@ -320,11 +305,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
return hidden return hidden
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor) -> None:
kv_caches: MinimaxCacheParams) -> None:
if not envs.VLLM_USE_V1:
self._forward(hidden_states, output, positions, kv_caches)
else:
torch.ops.vllm.linear_attention( torch.ops.vllm.linear_attention(
hidden_states, hidden_states,
output, output,
...@@ -333,11 +314,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): ...@@ -333,11 +314,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
) )
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor) -> None:
kv_caches: Optional[MinimaxCacheParams]) -> None:
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1 and attn_metadata is not None: if attn_metadata is not None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata) assert isinstance(attn_metadata, LinearAttentionMetadata)
...@@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): ...@@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact = torch.nn.functional.silu(qkv32) qkvact = torch.nn.functional.silu(qkv32)
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if envs.VLLM_USE_V1:
if attn_metadata is not None: if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0] kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
num_prefills = getattr(attn_metadata, "num_prefills", 0) num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0: if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata, num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
"num_decode_tokens", 0) 0)
for prefill_idx in range(num_prefills): for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[ q_start = attn_metadata.query_start_loc[num_decode_tokens +
num_decode_tokens + prefill_idx] prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens q_end = attn_metadata.query_start_loc[num_decode_tokens +
+ prefill_idx + prefill_idx + 1]
1]
query_len = q_end - q_start query_len = q_end - q_start
context_len = attn_metadata.seq_lens[ context_len = attn_metadata.seq_lens[
num_decode_tokens + prefill_idx] - query_len num_decode_tokens + prefill_idx] - query_len
if context_len == 0: if context_len == 0:
block_to_clear = state_indices_tensor[ block_to_clear = state_indices_tensor[num_decode_tokens
num_decode_tokens + prefill_idx] + prefill_idx]
kv_cache[block_to_clear, ...] = 0 kv_cache[block_to_clear, ...] = 0
else:
assert kv_caches is not None
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if attn_metadata is None: if attn_metadata is None:
...@@ -410,8 +384,7 @@ def linear_attention( ...@@ -410,8 +384,7 @@ def linear_attention(
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states, self._forward(hidden_states=hidden_states,
output=output, output=output,
positions=positions, positions=positions)
kv_caches=None)
def linear_attention_fake( def linear_attention_fake(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.platforms import current_platform
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
@dataclass
class Mamba2Metadata:
prep_initial_states: bool
chunk_size: int
has_initial_states_p: torch.Tensor
seq_idx_p: torch.Tensor
chunk_indices_p: torch.Tensor
chunk_offsets_p: torch.Tensor
"""
With continuous batching layout of `x` in vLLM, to enable a Triton program
to handle a request in parallel, two supporting tensors are used
(batch_ptr, token_chunk_offset_ptr)
BLOCK_M = the # tokens to be handled by a Triton program
(can be customized for different hardware)
nums_dict:
tracks the data associated with a given value of BLOCK_M
BLOCK_M = #tokens handled by a Triton program
cu_seqlen: total tokens per batch
(used as flag to update other data at each new input)
batch_ptr: tracks batch-id handled by the Triton program
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
(Triton implementation of causal_conv1d handles parallelism in 3-axes
- feature-axis
- batch-axis
- sequence-axis)
"""
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
"""Returns the appropriate metadata classes for the current platform."""
if current_platform.is_rocm():
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.triton_attn import (
TritonAttentionMetadata)
return (AiterFlashAttentionMetadata, TritonAttentionMetadata,
PlaceholderAttentionMetadata)
if current_platform.is_cuda():
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata)
from vllm.v1.attention.backends.xformers import (
XFormersAttentionMetadata)
return (FlashAttentionMetadata, XFormersAttentionMetadata,
PlaceholderAttentionMetadata)
raise ValueError(
f"Unsupported platform for Mamba2: {current_platform.device_type}")
def prepare_mamba2_metadata(
chunk_size: int,
attn_metadata: AttentionMetadata,
) -> Mamba2Metadata:
# compute number of prefill and decode requests
# NOTE: in V0 we assume prefills are before decodes
num_prefills = attn_metadata.num_prefills
num_prefill_tokens = attn_metadata.num_prefill_tokens
seq_idx_p = None
chunk_indices_p, chunk_offsets_p = None, None
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states_p = None
prep_initial_states = False
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:
attn_metadata_instances = get_platform_metadata_classes()
if (isinstance(attn_metadata, attn_metadata_instances)
and attn_metadata.context_lens_tensor is not None):
# precompute flag to avoid device syncs later in mamba2 layer
# forwards
# prep is only needed for mamba2 ssd prefill processing
has_initial_states_p = (
attn_metadata.context_lens_tensor[:num_prefills] > 0)
prep_initial_states = torch.any(has_initial_states_p).item()
query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
seq_idx_p = torch.repeat_interleave(torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
query_start_loc_p.diff(),
output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)
# We compute metadata for chunked prefill once at the top level model
# forward and reuse them in mamba layers. If not needed, they will be
# ignored inside mamba kernels.
if prep_initial_states:
chunk_indices_p, chunk_offsets_p = \
_query_start_loc_to_chunk_indices_offsets(
query_start_loc_p, chunk_size, num_prefill_tokens)
return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
seq_idx_p=seq_idx_p,
chunk_indices_p=chunk_indices_p,
chunk_offsets_p=chunk_offsets_p)
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
mamba2_metadata: Union[Mamba2Metadata,
Mamba2AttentionMetadata,
GDNAttentionMetadata]):
"""
this is triggered upon handling a new input at the first layer
"""
dim, cu_seqlen = x.shape
mamba2_metadata.cu_seqlen = cu_seqlen
seqlens = np.diff(query_start_loc.to('cpu'))
nums_dict = {} # type: ignore
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]['nums'] = nums
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]['mlist'] = mlist
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = [] # type: ignore
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
if mamba2_metadata.batch_ptr is None:
# Update default value after class definition
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
mamba2_metadata.token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
else:
if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
PAD_SLOT_ID)
mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
mamba2_metadata.token_chunk_offset_ptr[ # type: ignore
0:mlist_len].copy_(offsetlist)
nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
mamba2_metadata.token_chunk_offset_ptr) # type: ignore
mamba2_metadata.nums_dict = nums_dict
return mamba2_metadata
...@@ -10,8 +10,6 @@ import torch ...@@ -10,8 +10,6 @@ import torch
from torch import nn from torch import nn
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
...@@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( ...@@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update) causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update) selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp):
has_weight=rms_norm_has_weight, has_weight=rms_norm_has_weight,
) if use_rms_norm else None ) if use_rms_norm else None
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state) # The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))] self.kv_cache = (torch.tensor([]), torch.tensor([]))
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
...@@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp):
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
return discrete_time_step, B, C return discrete_time_step, B, C
def forward(self, def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params)
else:
torch.ops.vllm.mamba_mixer( torch.ops.vllm.mamba_mixer(
hidden_states, hidden_states,
output, output,
self.prefix, self.prefix,
) )
def forward_native(self, def forward_native(self, hidden_states: torch.Tensor,
hidden_states: torch.Tensor, output: torch.Tensor):
output: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
pass pass
def forward_cuda(self, def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
""" """
Run the Mamba-1 SSM pipeline. Run the Mamba-1 SSM pipeline.
...@@ -234,7 +216,6 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -234,7 +216,6 @@ class MambaMixer(MambaBase, CustomOp):
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None: if attn_metadata is not None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
...@@ -247,18 +228,6 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -247,18 +228,6 @@ class MambaMixer(MambaBase, CustomOp):
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
has_initial_states = mamba1_metadata.has_initial_states has_initial_states = mamba1_metadata.has_initial_states
num_padded_decodes = mamba1_metadata.num_padded_decodes num_padded_decodes = mamba1_metadata.num_padded_decodes
else:
assert isinstance(attn_metadata, AttentionMetadata)
assert mamba_cache_params is not None
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
query_start_loc = attn_metadata.query_start_loc
context_lens_tensor = attn_metadata.context_lens_tensor
has_initial_states = None
if context_lens_tensor is not None:
has_initial_states = context_lens_tensor > 0
num_padded_decodes = attn_metadata.num_decode_tokens
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
...@@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2)) self.conv1d.weight.size(2))
if envs.VLLM_USE_V1 and attn_metadata is None: if attn_metadata is None:
# V1 profile run # V1 profile run
hidden_states_BC = hidden_states_BC.contiguous() hidden_states_BC = hidden_states_BC.contiguous()
return self.out_proj(hidden_states_BC.transpose(-2, -1))[0] return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
...@@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp):
out=scan_outputs_d) out=scan_outputs_d)
scan_outputs_d = scan_outputs_d.transpose(0, 1) scan_outputs_d = scan_outputs_d.transpose(0, 1)
if envs.VLLM_USE_V1:
ssm_outputs.insert(0, scan_outputs_d) ssm_outputs.insert(0, scan_outputs_d)
else:
ssm_outputs.append(scan_outputs_d)
scan_outputs_combined = ssm_outputs[0] if len( scan_outputs_combined = ssm_outputs[0] if len(
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
...@@ -441,9 +407,9 @@ def split_batch_to_prefill_and_decode( ...@@ -441,9 +407,9 @@ def split_batch_to_prefill_and_decode(
num_decodes: int, num_decodes: int,
num_padded_decodes: int, num_padded_decodes: int,
) -> PrefillDecodeSplit: ) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_padded_decodes num_actual_tokens = num_prefill_tokens + num_padded_decodes
if envs.VLLM_USE_V1:
# In v1, decode tokens come first, then prefill tokens. # In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split( hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC[..., :num_actual_tokens], hidden_states_BC[..., :num_actual_tokens],
...@@ -462,19 +428,6 @@ def split_batch_to_prefill_and_decode( ...@@ -462,19 +428,6 @@ def split_batch_to_prefill_and_decode(
num_padded_decodes if num_prefills > 0 else None) num_padded_decodes if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[-num_prefills:] if ( has_initial_states_p = has_initial_states[-num_prefills:] if (
has_initial_states is not None and num_prefills > 0) else None has_initial_states is not None and num_prefills > 0) else None
else:
# In v0, prefill tokens come first, then decode tokens.
hidden_states_BC_p, hidden_states_BC_d = torch.split(
hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
gate_p, gate_d = torch.split(gate,
[num_prefill_tokens, num_decode_tokens],
dim=-1)
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor, [num_prefills, num_decodes], dim=0)
query_start_loc_p = (query_start_loc[:num_prefills +
1] if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[:num_prefills] if (
has_initial_states is not None and num_prefills > 0) else None
return PrefillDecodeSplit( return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p, hidden_states_BC_p=hidden_states_BC_p,
...@@ -495,9 +448,7 @@ def mamba_mixer( ...@@ -495,9 +448,7 @@ def mamba_mixer(
) -> None: ) -> None:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states, self.forward_cuda(hidden_states=hidden_states, output=output)
output=output,
mamba_cache_params=None)
def mamba_mixer_fake( def mamba_mixer_fake(
......
...@@ -9,7 +9,6 @@ if TYPE_CHECKING: ...@@ -9,7 +9,6 @@ if TYPE_CHECKING:
import torch import torch
from torch import nn from torch import nn
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...@@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
...@@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import ( ...@@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader) LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp):
self.use_rms_norm, self.use_rms_norm,
eps=rms_norm_eps) eps=rms_norm_eps)
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path # The tuple is (conv_state, ssm_state)
# only runs for v1, we have to do this to unify with the interface self.kv_cache = (torch.tensor([]), torch.tensor([]))
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
...@@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None, mup_vector: Optional[torch.Tensor] = None,
): ):
pass pass
...@@ -478,14 +468,8 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -478,14 +468,8 @@ class MambaMixer2(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None, mup_vector: Optional[torch.Tensor] = None,
): ):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
mamba2_metadata, mup_vector)
else:
torch.ops.vllm.mamba_mixer2( torch.ops.vllm.mamba_mixer2(
hidden_states, hidden_states,
output, output,
...@@ -497,40 +481,30 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -497,40 +481,30 @@ class MambaMixer2(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None, mup_vector: Optional[torch.Tensor] = None,
): ):
forward_context = get_forward_context() forward_context = get_forward_context()
# mamba2_metadata contains metadata necessary for the mamba2 triton # attn_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill # kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they # modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration # stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None: if attn_metadata is not None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
mamba2_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
else: has_initial_states_p = attn_metadata.has_initial_states_p
conv_state = mamba_cache_params.conv_state prep_initial_states = attn_metadata.prep_initial_states
ssm_state = mamba_cache_params.ssm_state chunk_size = attn_metadata.chunk_size
state_indices_tensor = mamba_cache_params.state_indices_tensor seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
# Common members between V1 metadata and V0 metadata chunk_offsets_p = attn_metadata.chunk_offsets_p
if mamba2_metadata is not None:
has_initial_states_p = mamba2_metadata.has_initial_states_p
prep_initial_states = mamba2_metadata.prep_initial_states
chunk_size = mamba2_metadata.chunk_size
seq_idx_p = mamba2_metadata.seq_idx_p
chunk_indices_p = mamba2_metadata.chunk_indices_p
chunk_offsets_p = mamba2_metadata.chunk_offsets_p
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states) projected_states, _ = self.in_proj(hidden_states)
...@@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1, dim=-1,
) )
if envs.VLLM_USE_V1 and attn_metadata is None: if attn_metadata is None:
# V1 profile run # profile run
hidden_states_B_C = (hidden_states_B_C.transpose( hidden_states_B_C = (hidden_states_B_C.transpose(
0, 1).clone().transpose(0, 1)).contiguous() 0, 1).clone().transpose(0, 1)).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn( hidden_states, _B, _C = split_hidden_states_B_C_fn(
...@@ -579,10 +553,8 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -579,10 +553,8 @@ class MambaMixer2(MambaBase, CustomOp):
has_decode = num_decodes > 0 has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes num_actual_tokens = num_prefill_tokens + num_decodes
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input # Separate prefill and decode by splitting varlen input
# Split along token dimension # Split along token dimension
if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split( hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C[:num_actual_tokens], hidden_states_B_C[:num_actual_tokens],
[num_decodes, num_prefill_tokens], [num_decodes, num_prefill_tokens],
...@@ -602,26 +574,6 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -602,26 +574,6 @@ class MambaMixer2(MambaBase, CustomOp):
query_start_loc_p = ( query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] - attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None) num_decodes if has_prefill else None)
else:
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)
# Preallocate output tensor to avoid memcpy cost for merging prefill # Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs # and decode outputs
...@@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp):
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device=hidden_states.device, device=hidden_states.device,
) )
if envs.VLLM_USE_V1:
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out, preallocated_ssm_out,
[num_decodes, num_prefill_tokens], [num_decodes, num_prefill_tokens],
dim=0, dim=0,
) )
else:
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests # Process prefill requests
if has_prefill: if has_prefill:
...@@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp):
# pointed to by "state_indices_tensor" # pointed to by "state_indices_tensor"
x = hidden_states_B_C_p.transpose( x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see 0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(x, query_start_loc_p,
mamba2_metadata)
hidden_states_B_C_p = causal_conv1d_fn( hidden_states_B_C_p = causal_conv1d_fn(
x, x,
conv_weights, conv_weights,
...@@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp):
conv_states=conv_state, conv_states=conv_state,
has_initial_state=has_initial_states_p, has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p, cache_indices=state_indices_tensor_p,
metadata=mamba2_metadata, metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose( query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens] 0, 1)[:num_prefill_tokens]
...@@ -806,8 +748,6 @@ def mamba_mixer2( ...@@ -806,8 +748,6 @@ def mamba_mixer2(
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states, self.forward_cuda(hidden_states=hidden_states,
output=output, output=output,
mamba_cache_params=None,
mamba2_metadata=None,
mup_vector=mup_vector) mup_vector=mup_vector)
......
...@@ -100,7 +100,6 @@ class MambaStateShapeCalculator: ...@@ -100,7 +100,6 @@ class MambaStateShapeCalculator:
intermediate_size: int, intermediate_size: int,
state_size: int, state_size: int,
conv_kernel: int, conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int]]: ) -> tuple[tuple[int, int], tuple[int, int]]:
conv_state_shape = (divide(intermediate_size, conv_state_shape = (divide(intermediate_size,
tp_world_size), conv_kernel - 1) tp_world_size), conv_kernel - 1)
...@@ -108,10 +107,6 @@ class MambaStateShapeCalculator: ...@@ -108,10 +107,6 @@ class MambaStateShapeCalculator:
temporal_state_shape = (divide(intermediate_size, temporal_state_shape = (divide(intermediate_size,
tp_world_size), state_size) tp_world_size), state_size)
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0] conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return conv_state_shape, temporal_state_shape return conv_state_shape, temporal_state_shape
...@@ -126,7 +121,6 @@ class MambaStateShapeCalculator: ...@@ -126,7 +121,6 @@ class MambaStateShapeCalculator:
head_dim: int, head_dim: int,
state_size: int, state_size: int,
conv_kernel: int, conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
# if n_groups is not divisible by world_size, need to extend the shards # if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it # to ensure all groups needed by a head is sharded along with it
...@@ -137,8 +131,6 @@ class MambaStateShapeCalculator: ...@@ -137,8 +131,6 @@ class MambaStateShapeCalculator:
# contiguous along 'dim' axis # contiguous along 'dim' axis
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
if not use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
# These are not TP-ed as they depend on A, dt_bias, D # These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small # - they are typically small
...@@ -153,12 +145,9 @@ class MambaStateShapeCalculator: ...@@ -153,12 +145,9 @@ class MambaStateShapeCalculator:
tp_world_size: int, tp_world_size: int,
intermediate_size: int, intermediate_size: int,
conv_kernel: int, conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int]]: ) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size) conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim) conv_state_shape = (conv_kernel - 1, conv_dim)
if not use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return (conv_state_shape, ) return (conv_state_shape, )
@classmethod @classmethod
...@@ -183,7 +172,6 @@ class MambaStateShapeCalculator: ...@@ -183,7 +172,6 @@ class MambaStateShapeCalculator:
head_v_dim: int, head_v_dim: int,
conv_kernel_size: int, conv_kernel_size: int,
num_spec: int = 0, num_spec: int = 0,
use_v1: bool = True,
): ):
conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads) conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
conv_state_shape = ( conv_state_shape = (
...@@ -191,10 +179,6 @@ class MambaStateShapeCalculator: ...@@ -191,10 +179,6 @@ class MambaStateShapeCalculator:
conv_kernel_size - 1 + num_spec, conv_kernel_size - 1 + num_spec,
) )
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0] conv_state_shape = conv_state_shape[1], conv_state_shape[0]
temporal_state_shape = (divide(num_v_heads, temporal_state_shape = (divide(num_v_heads,
......
...@@ -420,9 +420,7 @@ def causal_conv1d_fn( ...@@ -420,9 +420,7 @@ def causal_conv1d_fn(
x = x.to(conv_states.dtype) x = x.to(conv_states.dtype)
out = torch.empty_like(x) out = torch.empty_like(x)
if metadata is not None: if metadata is not None:
cu_seqlen = metadata.cu_seqlen
nums_dict = metadata.nums_dict nums_dict = metadata.nums_dict
#x = metadata.x
args = nums_dict args = nums_dict
batch_ptr = metadata.batch_ptr batch_ptr = metadata.batch_ptr
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
...@@ -926,7 +924,6 @@ def causal_conv1d_update( ...@@ -926,7 +924,6 @@ def causal_conv1d_update(
query_start_loc: Optional[torch.Tensor] = None, query_start_loc: Optional[torch.Tensor] = None,
max_query_len: int = -1, max_query_len: int = -1,
pad_slot_id: int = PAD_SLOT_ID, pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False, validate_data=False,
): ):
""" """
......
...@@ -8,7 +8,6 @@ if TYPE_CHECKING: ...@@ -8,7 +8,6 @@ if TYPE_CHECKING:
import torch import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
...@@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp): ...@@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp):
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
) )
assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")
compilation_config = get_current_vllm_config().compilation_config compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path self.kv_cache = (torch.tensor([]), )
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
self.kv_cache = [(torch.tensor([]), )]
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
...@@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp): ...@@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
): ):
return return
...@@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp): ...@@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
): ):
torch.ops.vllm.short_conv( torch.ops.vllm.short_conv(
hidden_states, hidden_states,
...@@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp): ...@@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
): ):
forward_context = get_forward_context() forward_context = get_forward_context()
# ShortConvAttentionMetadata contains metadata necessary for the # ShortConvAttentionMetadata contains metadata necessary for the
...@@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp): ...@@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp):
if attn_metadata is not None: if attn_metadata is not None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, ShortConvAttentionMetadata) assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
...@@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp): ...@@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp):
if has_prefill: if has_prefill:
Bx_p = (B_p * x_p).transpose(0, 1) Bx_p = (B_p * x_p).transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(Bx_p, query_start_loc_p,
conv_metadata)
Bx = causal_conv1d_fn(Bx_p, Bx = causal_conv1d_fn(Bx_p,
conv_weights, conv_weights,
self.conv.bias, self.conv.bias,
...@@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp): ...@@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp):
conv_states=conv_state, conv_states=conv_state,
has_initial_state=has_initial_states_p, has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p, cache_indices=state_indices_tensor_p,
metadata=conv_metadata, metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose( query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens] 0, 1)[:num_prefill_tokens]
...@@ -248,9 +235,7 @@ def short_conv( ...@@ -248,9 +235,7 @@ def short_conv(
) -> None: ) -> None:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states, self.forward_cuda(hidden_states=hidden_states, output=output)
output=output,
conv_metadata=None)
def short_conv_fake( def short_conv_fake(
......
...@@ -9,21 +9,17 @@ import torch ...@@ -9,21 +9,17 @@ import torch
from torch import nn from torch import nn
from transformers import BambaConfig from transformers import BambaConfig
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
...@@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant) SupportsQuant)
...@@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module): ...@@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
...@@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module): ...@@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states, residual) hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) self.mamba(hidden_states, output)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states) hidden_states = self.feed_forward(hidden_states)
...@@ -315,22 +306,10 @@ class BambaModel(nn.Module): ...@@ -315,22 +306,10 @@ class BambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -343,23 +322,11 @@ class BambaModel(nn.Module): ...@@ -343,23 +322,11 @@ class BambaModel(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
residual = None residual = None
num_attn = 0
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
if isinstance(layer, BambaAttentionDecoderLayer):
num_attn += 1
layer_mamba_cache_params = None
if isinstance(layer,
BambaMixerDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
...@@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim=hf_config.mamba_d_head, head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
...@@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
mamba_cache_params = None hidden_states = self.model(input_ids, positions, intermediate_tensors,
if not envs.VLLM_USE_V1: inputs_embeds)
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
class ConstantSizeCache(ABC):
"""
Abstract base class for managing constant size caches
like Mamba and Minimax.
"""
def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self.cache_indices_mapping: dict[str, dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
@abstractmethod
def cache(self) -> Any:
"""Return the underlying cache tensor(s)"""
pass
@abstractmethod
def _copy_cache(self, from_index: int, to_index: int):
"""Copy cache data from one index to another"""
pass
def current_run_tensors(self, **kwargs) -> tuple:
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
cache_tensors = self.cache
else:
# CUDA graph capturing runs
cache_tensors, state_indices_tensor = kwargs[
"seqlen_agnostic_capture_inputs"]
return (cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Cache during the CUDA graph replay
runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.cache, state_indices_tensor)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_cache(from_index=index_exists,
to_index=destination_index)
self.cache_indices_mapping[cur_rid][seq_id] = destination_index
return destination_index
else:
return self.cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_cache(
self, request_ids_to_seq_ids: dict[str, list[int]],
finished_requests_ids: list[str]) -> list[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: list[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.cache_indices_mapping:
for seq_id in self.cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.cache_indices_mapping[req_id][seq_id])
self.cache_indices_mapping.pop(req_id)
...@@ -8,21 +8,17 @@ import torch ...@@ -8,21 +8,17 @@ import torch
from torch import nn from torch import nn
from transformers import FalconH1Config from transformers import FalconH1Config
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
...@@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
...@@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module): ...@@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mamba( self.mamba(
hidden_states, hidden_states,
output, output,
mamba_cache_params,
mamba2_metadata=mamba2_metadata,
mup_vector=self.mup_vector, mup_vector=self.mup_vector,
) )
return output, residual return output, residual
...@@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module): ...@@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
residual = hidden_states residual = hidden_states
...@@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module): ...@@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module):
# Process input through the SSM branch. # Process input through the SSM branch.
# FalconH1SSMDecoderLayer expects hidden_states, attn_metadata, # FalconH1SSMDecoderLayer expects hidden_states, attn_metadata,
# residual, mamba_cache_params, and sequence_idx. # residual, and sequence_idx.
ssm_hidden, _ = self.mamba( ssm_hidden, _ = self.mamba(
hidden_states=hidden_states * self.ssm_in_multiplier, hidden_states=hidden_states * self.ssm_in_multiplier,
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
**kwargs, **kwargs,
) )
# Sum the outputs from both branches. # Sum the outputs from both branches.
...@@ -464,25 +450,10 @@ class FalconH1Model(nn.Module): ...@@ -464,25 +450,10 @@ class FalconH1Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds * self.embedding_multiplier hidden_states = inputs_embeds * self.embedding_multiplier
...@@ -495,14 +466,9 @@ class FalconH1Model(nn.Module): ...@@ -495,14 +466,9 @@ class FalconH1Model(nn.Module):
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
layer_mamba_cache_params = None
if mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
hidden_states = layer( hidden_states = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
...@@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
head_dim=hf_config.mamba_d_head, head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.tie_word_embeddings = config.tie_word_embeddings self.tie_word_embeddings = config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.mamba_cache: Optional[MambaCacheManager] = None
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
...@@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
**kwargs, **kwargs,
): ):
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(
self.vllm_config,
self.config.num_hidden_layers,
*mamba_state_shape,
*mamba_state_dtype,
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
positions, positions,
mamba_cache_params,
intermediate_tensors, intermediate_tensors,
inputs_embeds, inputs_embeds,
) )
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -9,19 +9,15 @@ import torch ...@@ -9,19 +9,15 @@ import torch
from torch import nn from torch import nn
from transformers import GraniteMoeHybridConfig from transformers import GraniteMoeHybridConfig
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
...@@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .granitemoe import GraniteMoeMoE from .granitemoe import GraniteMoeMoE
from .granitemoeshared import GraniteMoeSharedMLP from .granitemoeshared import GraniteMoeSharedMLP
...@@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): ...@@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) self.mamba(hidden_states, output)
hidden_states = residual + output * self.residual_multiplier hidden_states = residual + output * self.residual_multiplier
residual = hidden_states residual = hidden_states
...@@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): ...@@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module): ...@@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module): ...@@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer):
num_attn += 1 num_attn += 1
hidden_states, residual = layer(positions=positions,
layer_mamba_cache_params = None
if isinstance(
layer,
GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual)
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, ...@@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
...@@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, ...@@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
head_dim=hf_config.mamba_d_head, head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, ...@@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
scale=1 / scale=1 /
self.config.logits_scaling) self.config.logits_scaling)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, ...@@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
mamba_cache_params = None hidden_states = self.model(input_ids, positions, intermediate_tensors,
if not envs.VLLM_USE_V1: inputs_embeds)
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -9,7 +9,6 @@ import torch ...@@ -9,7 +9,6 @@ import torch
from torch import nn from torch import nn
from transformers import JambaConfig from transformers import JambaConfig
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
...@@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaMLP as JambaMLP from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
...@@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
...@@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual) hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params) self.mamba(hidden_states, output)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states) hidden_states = self.feed_forward(hidden_states)
...@@ -333,7 +328,6 @@ class JambaModel(nn.Module): ...@@ -333,7 +328,6 @@ class JambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -348,24 +342,11 @@ class JambaModel(nn.Module): ...@@ -348,24 +342,11 @@ class JambaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
kv_cache_index = 0
mamba_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None hidden_states, residual = layer(positions=positions,
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache_index += 1
if isinstance(layer,
JambaMambaDecoderLayer) and mamba_cache_params:
current_state_layer = mamba_cache_index
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer)
mamba_cache_index += 1
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual)
mamba_cache_params=layer_mamba_cache_params)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
...@@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
...@@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None hidden_states = self.model(input_ids, positions, intermediate_tensors,
if not envs.VLLM_USE_V1: inputs_embeds)
if self.mamba_cache is None:
num_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
state_dtype = self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_layers, *state_shape,
*state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
...@@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_size=hf_config.mamba_expand * hidden_size, intermediate_size=hf_config.mamba_expand * hidden_size,
state_size=hf_config.mamba_d_state, state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv, conv_kernel=hf_config.mamba_d_conv,
use_v1=envs.VLLM_USE_V1,
) )
def compute_logits( def compute_logits(
......
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import Lfm2Config from transformers import Lfm2Config
from vllm import envs
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
...@@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module): ...@@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
self.conv( self.conv(
hidden_states, hidden_states,
output, output,
conv_metadata=None,
) )
hidden_states, residual = self.ffn_norm(output, residual) hidden_states, residual = self.ffn_norm(output, residual)
hidden_states = self.feed_forward(hidden_states) hidden_states = self.feed_forward(hidden_states)
...@@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int]]: ) -> tuple[tuple[int, int]]:
""" Calculate shapes for LFM2's convolutional cache. """ Calculate shapes for LFM2's convolutional cache.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
...@@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
tp_world_size=parallel_config.tensor_parallel_size, tp_world_size=parallel_config.tensor_parallel_size,
intermediate_size=hf_config.conv_dim, intermediate_size=hf_config.conv_dim,
conv_kernel=hf_config.conv_L_cache, conv_kernel=hf_config.conv_L_cache,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
...@@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
assert (not cache_config.enable_prefix_caching assert (not cache_config.enable_prefix_caching
), "Lfm2 currently does not support prefix caching" ), "Lfm2 currently does not support prefix caching"
assert envs.VLLM_USE_V1, (
"Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1")
super().__init__() super().__init__()
self.config = config self.config = config
......
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from torch import nn from torch import nn
from transformers import MambaConfig from transformers import MambaConfig
from vllm import envs
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
...@@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState, from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree, SupportsPP) IsAttentionFree, SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .utils import (AutoWeightsLoader, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
...@@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module): ...@@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
...@@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module): ...@@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params) self.mixer(hidden_states, output)
return output, residual return output, residual
...@@ -134,7 +129,6 @@ class MambaModel(nn.Module): ...@@ -134,7 +129,6 @@ class MambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -151,17 +145,9 @@ class MambaModel(nn.Module): ...@@ -151,17 +145,9 @@ class MambaModel(nn.Module):
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions=positions,
layer_cache_params = None
if mamba_cache_params is not None:
layer_cache_params = mamba_cache_params.at_layer_idx(
i - self.start_layer)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual)
mamba_cache_params=layer_cache_params)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
...@@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): ...@@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
...@@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): ...@@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
mamba_cache_params = None hidden_states = self.backbone(input_ids, positions,
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
state_dtype = self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_layers, *state_shape,
*state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
...@@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): ...@@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
tp_world_size=parallel_config.tensor_parallel_size, tp_world_size=parallel_config.tensor_parallel_size,
intermediate_size=hf_config.intermediate_size, intermediate_size=hf_config.intermediate_size,
state_size=hf_config.state_size, state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel, conv_kernel=hf_config.conv_kernel)
use_v1=envs.VLLM_USE_V1)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs( return self.mamba_cache.copy_inputs_before_cuda_graphs(
......
...@@ -8,16 +8,11 @@ import torch ...@@ -8,16 +8,11 @@ import torch
from torch import nn from torch import nn
from transformers import MambaConfig from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
...@@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState, from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree) IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .utils import (AutoWeightsLoader, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
...@@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module): ...@@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
...@@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module): ...@@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) self.mixer(hidden_states, output)
return output, residual return output, residual
...@@ -137,7 +127,6 @@ class Mamba2Model(nn.Module): ...@@ -137,7 +127,6 @@ class Mamba2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -152,25 +141,10 @@ class Mamba2Model(nn.Module): ...@@ -152,25 +141,10 @@ class Mamba2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(positions=positions,
positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual)
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer) if mamba_cache_params else None,
mamba2_metadata=mamba2_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def get_mamba_state_shape_from_config( def get_mamba_state_shape_from_config(
cls, cls,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Args: Args:
vllm_config: vLLM config vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns: Returns:
Tuple containing: Tuple containing:
...@@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
head_dim=hf_config.head_dim, head_dim=hf_config.head_dim,
state_size=hf_config.state_size, state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel, conv_kernel=hf_config.conv_kernel,
use_v1=use_v1,
) )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
...@@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
hidden_states = self.backbone(input_ids, positions, mamba_cache_params, hidden_states = self.backbone(input_ids, positions,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MambaCacheParams:
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MambaCacheParams(self.conv_state[layer_idx],
self.ssm_state[layer_idx],
self.state_indices_tensor)
class MambaCacheManager(ConstantSizeCache):
def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
conv_state_shape: tuple[int, int],
temporal_state_shape: tuple[int, int],
conv_state_dtype: torch.dtype,
temporal_state_dtype: torch.dtype):
self.conv_state_dtype = conv_state_dtype
self.temporal_state_dtype = temporal_state_dtype
# Determine max batch size to set size of MambaCache
max_batch_size = vllm_config.scheduler_config.max_num_seqs
if not vllm_config.model_config.enforce_eager:
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
# Initialize parent class
super().__init__(max_batch_size)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
(conv_state_shape[1], conv_state_shape[0]),
dtype=self.conv_state_dtype,
device="cuda").transpose(-1, -2)
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=self.temporal_state_dtype,
device="cuda")
self._mamba_cache = (conv_state, temporal_state)
@property
def cache(self):
return self._mamba_cache
def _copy_cache(self, from_index: int, to_index: int):
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def current_run_tensors(self, **kwargs) -> MambaCacheParams:
"""
Return the tensors for the current run's conv and ssm state.
"""
cache_tensors, state_indices_tensor = super().current_run_tensors(
**kwargs)
return MambaCacheParams(cache_tensors[0], cache_tensors[1],
state_indices_tensor)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MinimaxCacheParams:
minimax_cache: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
self.state_indices_tensor)
class MinimaxCacheManager(ConstantSizeCache):
def __init__(self, dtype, cache_shape):
super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1]
self._minimax_cache = torch.empty(size=cache_shape,
dtype=dtype,
device="cuda")
@property
def cache(self):
return self._minimax_cache
def _copy_cache(self, from_index: int, to_index: int):
assert len(self.cache) > 0
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment