"vscode:/vscode.git/clone" did not exist on "df29793dc73a83f3c86c19de967adffda1a28a93"
Unverified Commit 2e2000f3 authored by Paul Pak's avatar Paul Pak Committed by GitHub
Browse files

[Model] Add LFM2 architecture (#22845)


Signed-off-by: default avatarPaul Pak <paulpak58@gmail.com>
parent 31282401
...@@ -373,6 +373,7 @@ th { ...@@ -373,6 +373,7 @@ th {
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
......
...@@ -31,6 +31,7 @@ HYBRID_MODELS = [ ...@@ -31,6 +31,7 @@ HYBRID_MODELS = [
"hmellor/tiny-random-BambaForCausalLM", "hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview", "ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base", "tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
] ]
HF_UNSUPPORTED_MODELS = [ HF_UNSUPPORTED_MODELS = [
...@@ -52,6 +53,7 @@ V1_SUPPORTED_MODELS = [ ...@@ -52,6 +53,7 @@ V1_SUPPORTED_MODELS = [
"hmellor/tiny-random-BambaForCausalLM", "hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview", "ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base", "tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
] ]
FULL_CUDA_GRAPH_MODELS = [ FULL_CUDA_GRAPH_MODELS = [
...@@ -59,6 +61,10 @@ FULL_CUDA_GRAPH_MODELS = [ ...@@ -59,6 +61,10 @@ FULL_CUDA_GRAPH_MODELS = [
"Zyphra/Zamba2-1.2B-instruct", "Zyphra/Zamba2-1.2B-instruct",
] ]
V0_UNSUPPORTED_MODELS = [
"LiquidAI/LFM2-1.2B",
]
# Avoid OOM # Avoid OOM
MAX_NUM_SEQS = 4 MAX_NUM_SEQS = 4
...@@ -94,9 +100,12 @@ def test_models( ...@@ -94,9 +100,12 @@ def test_models(
else: else:
hf_outputs = None hf_outputs = None
if model not in V0_UNSUPPORTED_MODELS:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs( vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
else:
vllm_v0_outputs = None
if model in V1_SUPPORTED_MODELS: if model in V1_SUPPORTED_MODELS:
with monkeypatch.context() as m: with monkeypatch.context() as m:
...@@ -112,7 +121,7 @@ def test_models( ...@@ -112,7 +121,7 @@ def test_models(
else: else:
vllm_v1_outputs = None vllm_v1_outputs = None
if hf_outputs is not None: if hf_outputs is not None and vllm_v0_outputs is not None:
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v0_outputs, outputs_1_lst=vllm_v0_outputs,
...@@ -122,6 +131,7 @@ def test_models( ...@@ -122,6 +131,7 @@ def test_models(
if model in V1_SUPPORTED_MODELS: if model in V1_SUPPORTED_MODELS:
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
assert ref_outputs is not None
check_logprobs_close( check_logprobs_close(
outputs_0_lst=ref_outputs, outputs_0_lst=ref_outputs,
outputs_1_lst=vllm_v1_outputs, outputs_1_lst=vllm_v1_outputs,
...@@ -140,6 +150,9 @@ def test_batching( ...@@ -140,6 +150,9 @@ def test_batching(
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
) -> None: ) -> None:
if model in V0_UNSUPPORTED_MODELS:
pytest.skip(
f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
try: try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
...@@ -392,9 +405,12 @@ def test_full_cuda_graph( ...@@ -392,9 +405,12 @@ def test_full_cuda_graph(
else: else:
hf_outputs = None hf_outputs = None
if model not in V0_UNSUPPORTED_MODELS:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs( vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
else:
vllm_v0_outputs = None
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
...@@ -408,7 +424,7 @@ def test_full_cuda_graph( ...@@ -408,7 +424,7 @@ def test_full_cuda_graph(
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
if hf_outputs is not None: if hf_outputs is not None and vllm_v0_outputs is not None:
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v0_outputs, outputs_1_lst=vllm_v0_outputs,
...@@ -417,6 +433,7 @@ def test_full_cuda_graph( ...@@ -417,6 +433,7 @@ def test_full_cuda_graph(
) )
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
assert ref_outputs is not None
check_logprobs_close( check_logprobs_close(
outputs_0_lst=ref_outputs, outputs_0_lst=ref_outputs,
outputs_1_lst=vllm_v1_outputs, outputs_1_lst=vllm_v1_outputs,
......
...@@ -230,6 +230,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -230,6 +230,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"tiny": "ai21labs/Jamba-tiny-dev", "tiny": "ai21labs/Jamba-tiny-dev",
"random": "ai21labs/Jamba-tiny-random", # noqa: E501 "random": "ai21labs/Jamba-tiny-random", # noqa: E501
}), }),
"Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B",
min_transformers_version="4.54"),
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501
......
...@@ -95,6 +95,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, ...@@ -95,6 +95,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
if model_arch == "Lfm2ForCausalLM":
pytest.skip("Skipping until test supports V1-only models")
can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS)
......
...@@ -337,6 +337,7 @@ class CompilationConfig: ...@@ -337,6 +337,7 @@ class CompilationConfig:
"vllm.unified_attention_with_output", "vllm.unified_attention_with_output",
"vllm.mamba_mixer2", "vllm.mamba_mixer2",
"vllm.mamba_mixer", "vllm.mamba_mixer",
"vllm.short_conv",
] ]
def compute_hash(self) -> str: def compute_hash(self) -> str:
......
...@@ -54,6 +54,16 @@ class MambaStateDtypeCalculator: ...@@ -54,6 +54,16 @@ class MambaStateDtypeCalculator:
return (conv_state_dtype, temporal_state_dtype) return (conv_state_dtype, temporal_state_dtype)
@classmethod
def short_conv_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
return (conv_state_dtype, )
class MambaStateShapeCalculator: class MambaStateShapeCalculator:
...@@ -122,6 +132,20 @@ class MambaStateShapeCalculator: ...@@ -122,6 +132,20 @@ class MambaStateShapeCalculator:
tp_world_size), head_dim, state_size) tp_world_size), head_dim, state_size)
return conv_state_shape, temporal_state_shape return conv_state_shape, temporal_state_shape
@classmethod
def short_conv_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim)
if not use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return (conv_state_shape, )
@classmethod @classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for """Compute the increase in group numbers to account for
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionMetadata)
@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):
def __init__(self,
config,
dim: int,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.conv_dim = dim
self.L_cache = config.conv_L_cache
self.bias = config.conv_bias
self.conv = ColumnParallelLinear(
input_size=self.L_cache,
output_size=dim,
bias=self.bias,
prefix=f"{prefix}.conv1d",
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self.conv.weight.data = self.conv.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(
input_size=dim,
output_sizes=[dim] * 3,
bias=self.bias,
prefix=f"{prefix}.in_proj",
)
self.out_proj = RowParallelLinear(
input_size=dim,
output_size=dim,
bias=self.bias,
prefix=f"{prefix}.out_proj",
)
assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
self.kv_cache = [(torch.tensor([]), )]
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
def forward_native(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
return
def forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
torch.ops.vllm.short_conv(
hidden_states,
output,
self.prefix,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
forward_context = get_forward_context()
# ShortConvAttentionMetadata contains metadata necessary for the
# short_conv triton kernels to operate in continuous batching and in
# chunked prefill modes; they are computed at top-level model forward
# since they stay the same and reused for all mamba layers in the same
# iteration.
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states
BCx, _ = self.in_proj(hidden_states)
B, C, x = BCx.chunk(3, dim=-1)
conv_weights = self.conv.weight.view(self.conv.weight.size(0),
self.conv.weight.size(2))
if attn_metadata is None:
# V1 profile run
Bx = (B * x).contiguous()
hidden_states = C * Bx
contextualized_states, _ = self.out_proj(hidden_states)
return contextualized_states
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_decodes + num_prefill_tokens
# NOTE: V1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
B_d, B_p = torch.split(
B[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
C_d, C_p = torch.split(
C[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
x_d, x_p = torch.split(
x[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
conv_output_list = []
if has_prefill:
Bx_p = (B_p * x_p).transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(Bx_p, query_start_loc_p,
conv_metadata)
Bx = causal_conv1d_fn(Bx_p,
conv_weights,
self.conv.bias,
activation=None,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=conv_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
y = C_p * Bx
conv_output_list.append(y)
if has_decode:
Bx_d = (B_d * x_d).contiguous()
Bx = causal_conv1d_update(
Bx_d,
conv_state,
conv_weights,
self.conv.bias,
activation=None,
conv_state_indices=state_indices_tensor_d)
y = C_d * Bx
conv_output_list.insert(0, y)
# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(conv_output_list)
# Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
def get_state_dtype(self) -> tuple[torch.dtype, ...]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.short_conv_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...]]:
return MambaStateShapeCalculator.short_conv_state_shape(
tp_world_size=get_tensor_model_parallel_world_size(),
intermediate_size=self.conv_dim,
conv_kernel=self.L_cache,
)
@property
def mamba_type(self) -> str:
return "short_conv"
def short_conv(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
conv_metadata=None)
def short_conv_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="short_conv",
op_func=short_conv,
mutates_args=["output"],
fake_impl=short_conv_fake,
dispatch_key=current_platform.dispatch_key,
)
This diff is collapsed.
...@@ -93,6 +93,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -93,6 +93,7 @@ _TEXT_GENERATION_MODELS = {
"InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501 "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501
# For decapoda-research/llama-* # For decapoda-research/llama-*
......
...@@ -4,6 +4,8 @@ from vllm.attention.backends.abstract import AttentionBackend ...@@ -4,6 +4,8 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
...@@ -13,6 +15,8 @@ def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: ...@@ -13,6 +15,8 @@ def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
return Mamba2AttentionBackend return Mamba2AttentionBackend
if mamba_type == "linear_attention": if mamba_type == "linear_attention":
return LinearAttentionBackend return LinearAttentionBackend
if mamba_type == "short_conv":
return ShortConvAttentionBackend
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
"supported yet.") "supported yet.")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class ShortConvAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
return ShortConvAttentionMetadataBuilder
@dataclass
class ShortConvAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
has_initial_states: torch.Tensor
state_indices_tensor: torch.Tensor # shape: [batch,]
# For causal_conv1d
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
class ShortConvAttentionMetadataBuilder(
AttentionMetadataBuilder[ShortConvAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> ShortConvAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
has_initial_states = None
if num_prefills > 0:
#[batch,]
has_initial_states_cpu = (
common_attn_metadata.
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
has_initial_states = has_initial_states_cpu.to(
query_start_loc.device)
attn_metadata = ShortConvAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc=query_start_loc,
has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor,
)
return attn_metadata
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment