Commit 160c6fa3 (unverified), authored by Nicolò Lucchesi, committed by GitHub
Browse files

[Misc] Add `get_name` to missing AttentionBackends (#32698)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent a8eb1182
...@@ -22,6 +22,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec ...@@ -22,6 +22,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class GDNAttentionBackend(AttentionBackend): class GDNAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "GDN_ATTN"
@staticmethod @staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
return GDNAttentionMetadataBuilder return GDNAttentionMetadataBuilder
......
...@@ -16,6 +16,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec ...@@ -16,6 +16,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class LinearAttentionBackend(AttentionBackend): class LinearAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "LINEAR_ATTN"
@staticmethod @staticmethod
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]: def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
return LinearAttentionMetadataBuilder return LinearAttentionMetadataBuilder
......
...@@ -11,6 +11,10 @@ from vllm.v1.attention.backends.mamba_attn import ( ...@@ -11,6 +11,10 @@ from vllm.v1.attention.backends.mamba_attn import (
class Mamba1AttentionBackend(AttentionBackend): class Mamba1AttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "MAMBA1_ATTN"
@staticmethod @staticmethod
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
return Mamba1AttentionMetadataBuilder return Mamba1AttentionMetadataBuilder
......
...@@ -7,7 +7,10 @@ import torch ...@@ -7,7 +7,10 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata from vllm.v1.attention.backend import (
AttentionBackend,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.mamba_attn import ( from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata, BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder, BaseMambaAttentionMetadataBuilder,
...@@ -85,6 +88,10 @@ def compute_varlen_chunk_metadata( ...@@ -85,6 +88,10 @@ def compute_varlen_chunk_metadata(
class Mamba2AttentionBackend(AttentionBackend): class Mamba2AttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "MAMBA2_ATTN"
@staticmethod @staticmethod
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]: def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
return Mamba2AttentionMetadataBuilder return Mamba2AttentionMetadataBuilder
......
...@@ -25,6 +25,10 @@ logger = init_logger(__name__) ...@@ -25,6 +25,10 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend): class DeepseekV32IndexerBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "DEEPSEEK_V32_INDEXER"
@staticmethod @staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1 if current_platform.is_rocm() else 64] return [1 if current_platform.is_rocm() else 64]
......
...@@ -10,6 +10,10 @@ from vllm.v1.attention.backends.mamba_attn import ( ...@@ -10,6 +10,10 @@ from vllm.v1.attention.backends.mamba_attn import (
class ShortConvAttentionBackend(AttentionBackend): class ShortConvAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "SHORT_CONV_ATTN"
@staticmethod @staticmethod
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
return ShortConvAttentionMetadataBuilder return ShortConvAttentionMetadataBuilder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment