Unverified Commit 160c6fa3 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Misc] Add `get_name` to missing AttentionBackends (#32698)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent a8eb1182
......@@ -22,6 +22,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class GDNAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "GDN_ATTN"
@staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
return GDNAttentionMetadataBuilder
......
......@@ -16,6 +16,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class LinearAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "LINEAR_ATTN"
@staticmethod
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
return LinearAttentionMetadataBuilder
......
......@@ -11,6 +11,10 @@ from vllm.v1.attention.backends.mamba_attn import (
class Mamba1AttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "MAMBA1_ATTN"
@staticmethod
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
return Mamba1AttentionMetadataBuilder
......
......@@ -7,7 +7,10 @@ import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
from vllm.v1.attention.backend import (
AttentionBackend,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
......@@ -85,6 +88,10 @@ def compute_varlen_chunk_metadata(
class Mamba2AttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "MAMBA2_ATTN"
@staticmethod
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
return Mamba2AttentionMetadataBuilder
......
......@@ -25,6 +25,10 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "DEEPSEEK_V32_INDEXER"
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1 if current_platform.is_rocm() else 64]
......
......@@ -10,6 +10,10 @@ from vllm.v1.attention.backends.mamba_attn import (
class ShortConvAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "SHORT_CONV_ATTN"
@staticmethod
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
return ShortConvAttentionMetadataBuilder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment