Unverified commit 3468f17e, authored by Matthew Bonanni, committed by GitHub
Browse files

[V0 deprecation] Remove _VLLM_V1 suffixes from attention backend names (#25489)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
parent 71b25b0d
...@@ -186,6 +186,14 @@ def _cached_get_attn_backend( ...@@ -186,6 +186,14 @@ def _cached_get_attn_backend(
# Check the environment variable and override if specified # Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None: if backend_by_env_var is not None:
if backend_by_env_var.endswith("_VLLM_V1"):
logger.warning(
"The suffix '_VLLM_V1' in the environment variable "
"%s is no longer necessary as V0 backends have been "
"deprecated. Please remove this suffix from your "
"environment variable setting.", STR_BACKEND_ENV_VAR)
backend_by_env_var = backend_by_env_var.removesuffix(
"_VLLM_V1")
selected_backend = backend_name_to_enum(backend_by_env_var) selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None: if selected_backend is None:
raise ValueError( raise ValueError(
......
...@@ -577,8 +577,8 @@ class NixlConnectorWorker: ...@@ -577,8 +577,8 @@ class NixlConnectorWorker:
use_mla=self.use_mla) use_mla=self.use_mla)
self.backend_name = backend.get_name() self.backend_name = backend.get_name()
attn_backend = backend_name_to_enum(self.backend_name) attn_backend = backend_name_to_enum(self.backend_name)
self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1 self._use_flashinfer = attn_backend == _Backend.FLASHINFER
self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1 self._use_pallas = attn_backend == _Backend.PALLAS
self.kv_cache_layout = get_kv_cache_layout() self.kv_cache_layout = get_kv_cache_layout()
logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout) logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
...@@ -749,7 +749,7 @@ class NixlConnectorWorker: ...@@ -749,7 +749,7 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB). # (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region # Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim). # to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v = not (self.use_mla or self._use_pallas_v1 split_k_and_v = not (self.use_mla or self._use_pallas
or self._use_flashinfer) or self._use_flashinfer)
tensor_size_bytes = None tensor_size_bytes = None
for layer_name, cache_or_caches in xfer_buffers.items(): for layer_name, cache_or_caches in xfer_buffers.items():
...@@ -938,7 +938,7 @@ class NixlConnectorWorker: ...@@ -938,7 +938,7 @@ class NixlConnectorWorker:
tp_ratio = divide(self._tp_size[self.engine_id], tp_ratio = divide(self._tp_size[self.engine_id],
self._tp_size[engine_id]) self._tp_size[engine_id])
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
assert not self._use_pallas_v1 or tp_ratio == 1, \ assert not self._use_pallas or tp_ratio == 1, \
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet." "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
# Handle tp_size>num_kv_heads: replicate KV cache. # Handle tp_size>num_kv_heads: replicate KV cache.
......
...@@ -1479,25 +1479,21 @@ class EngineArgs: ...@@ -1479,25 +1479,21 @@ class EngineArgs:
"such as ngram, medusa, eagle, or deepseek_mtp.") "such as ngram, medusa, eagle, or deepseek_mtp.")
V1_BACKENDS = [ V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1",
"FLASH_ATTN", "FLASH_ATTN",
"PALLAS", "PALLAS",
"PALLAS_VLLM_V1", "TRITON_ATTN",
"TRITON_ATTN_VLLM_V1",
"TRITON_MLA", "TRITON_MLA",
"CUTLASS_MLA", "CUTLASS_MLA",
"FLASHMLA", "FLASHMLA",
"FLASHMLA_VLLM_V1",
"FLASH_ATTN_MLA", "FLASH_ATTN_MLA",
"FLASHINFER", "FLASHINFER",
"FLASHINFER_VLLM_V1",
"FLASHINFER_MLA", "FLASHINFER_MLA",
"ROCM_AITER_MLA", "ROCM_AITER_MLA",
"TORCH_SDPA_VLLM_V1", "TORCH_SDPA",
"FLEX_ATTENTION", "FLEX_ATTENTION",
"TREE_ATTN", "TREE_ATTN",
"XFORMERS_VLLM_V1", "XFORMERS",
"ROCM_ATTN_VLLM_V1", "ROCM_ATTN",
] ]
if (envs.is_set("VLLM_ATTENTION_BACKEND") if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
......
...@@ -42,7 +42,7 @@ def kernel_warmup(worker: "Worker"): ...@@ -42,7 +42,7 @@ def kernel_warmup(worker: "Worker"):
# and is not a pooling model # and is not a pooling model
def _is_flashinfer_backend(backend): def _is_flashinfer_backend(backend):
try: try:
return backend.get_name() == "FLASHINFER_VLLM_V1" return backend.get_name() == "FLASHINFER"
except NotImplementedError: except NotImplementedError:
return False return False
......
...@@ -241,9 +241,8 @@ class CudaPlatformBase(Platform): ...@@ -241,9 +241,8 @@ class CudaPlatformBase(Platform):
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
selected_backend is None and cls.is_device_capability(100) selected_backend is None and cls.is_device_capability(100)
and block_size in [32, 64]) and block_size in [32, 64])
use_flashmla = selected_backend in [ use_flashmla = selected_backend == _Backend.FLASHMLA or (
_Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1 selected_backend is None and is_flashmla_supported()[0])
] or (selected_backend is None and is_flashmla_supported()[0])
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
selected_backend is None and flash_attn_supports_mla()) selected_backend is None and flash_attn_supports_mla())
use_triton = selected_backend == _Backend.TRITON_MLA or ( use_triton = selected_backend == _Backend.TRITON_MLA or (
...@@ -282,7 +281,7 @@ class CudaPlatformBase(Platform): ...@@ -282,7 +281,7 @@ class CudaPlatformBase(Platform):
if use_v1: if use_v1:
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
...@@ -300,16 +299,16 @@ class CudaPlatformBase(Platform): ...@@ -300,16 +299,16 @@ class CudaPlatformBase(Platform):
elif selected_backend == _Backend.FLEX_ATTENTION: elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info_once("Using FlexAttention backend on V1 engine.") logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1 return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1: elif selected_backend == _Backend.TRITON_ATTN:
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1 return TRITON_ATTN
elif selected_backend == _Backend.FLASH_ATTN: elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.") logger.info_once("Using Flash Attention backend on V1 engine.")
return FLASH_ATTN_V1 return FLASH_ATTN_V1
elif selected_backend == _Backend.TREE_ATTN: elif selected_backend == _Backend.TREE_ATTN:
logger.info_once("Using Tree Attention backend on V1 engine.") logger.info_once("Using Tree Attention backend on V1 engine.")
return TREE_ATTN_V1 return TREE_ATTN_V1
elif selected_backend == _Backend.XFORMERS_VLLM_V1: elif selected_backend == _Backend.XFORMERS:
logger.info_once("Using XFormers backend on V1 engine.") logger.info_once("Using XFormers backend on V1 engine.")
return XFORMERS_V1 return XFORMERS_V1
...@@ -341,7 +340,7 @@ class CudaPlatformBase(Platform): ...@@ -341,7 +340,7 @@ class CudaPlatformBase(Platform):
if (has_sink or if (has_sink or
use_fp8_kv_cache) and not cls.is_device_capability(90): use_fp8_kv_cache) and not cls.is_device_capability(90):
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1 return TRITON_ATTN
elif is_default_backend_supported := is_attn_backend_supported( elif is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, FLASH_ATTN_V1, head_size, dtype,
allow_import_error=False): allow_import_error=False):
...@@ -457,12 +456,12 @@ class CudaPlatformBase(Platform): ...@@ -457,12 +456,12 @@ class CudaPlatformBase(Platform):
else: else:
# Default to FlashAttention # Default to FlashAttention
if attention_backend is None: if attention_backend is None:
attention_backend = "FLASH_ATTN_VLLM_V1" attention_backend = "FLASH_ATTN"
# All Blackwell backends support fp8 # All Blackwell backends support fp8
if cls.is_device_capability(100): if cls.is_device_capability(100):
supported = True supported = True
elif attention_backend == "FLASH_ATTN_VLLM_V1": elif attention_backend == "FLASH_ATTN":
if fp8_attention: if fp8_attention:
from vllm.attention.utils.fa_utils import ( from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8) flash_attn_supports_fp8)
...@@ -471,7 +470,7 @@ class CudaPlatformBase(Platform): ...@@ -471,7 +470,7 @@ class CudaPlatformBase(Platform):
supported = True supported = True
elif attention_backend == "FLASHINFER": elif attention_backend == "FLASHINFER":
supported = True supported = True
elif attention_backend == "TRITON_ATTN_VLLM_V1": elif attention_backend == "TRITON_ATTN":
supported = cls.supports_fp8() supported = cls.supports_fp8()
return supported return supported
......
...@@ -40,34 +40,26 @@ def in_wsl() -> bool: ...@@ -40,34 +40,26 @@ def in_wsl() -> bool:
class _Backend(enum.Enum): class _Backend(enum.Enum):
FLASH_ATTN = enum.auto() FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto() TRITON_ATTN = enum.auto()
TRITON_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto() XFORMERS = enum.auto()
ROCM_FLASH = enum.auto() ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto() # Supported by V1 ROCM_AITER_MLA = enum.auto() # Supported by V1
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
ROCM_AITER_FA = enum.auto() # used for ViT attn backend ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto() TORCH_SDPA = enum.auto()
TORCH_SDPA_VLLM_V1 = enum.auto()
FLASHINFER = enum.auto() FLASHINFER = enum.auto()
FLASHINFER_VLLM_V1 = enum.auto()
FLASHINFER_MLA = enum.auto() FLASHINFER_MLA = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1 TRITON_MLA = enum.auto() # Supported by V1
TRITON_MLA_VLLM_V1 = enum.auto()
CUTLASS_MLA = enum.auto() CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto() # Supported by V1 FLASHMLA = enum.auto() # Supported by V1
FLASHMLA_VLLM_V1 = enum.auto()
FLASH_ATTN_MLA = enum.auto() # Supported by V1 FLASH_ATTN_MLA = enum.auto() # Supported by V1
PALLAS = enum.auto() PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto() IPEX = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto() DUAL_CHUNK_FLASH_ATTN = enum.auto()
DIFFERENTIAL_FLASH_ATTN = enum.auto() DIFFERENTIAL_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto() NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto() TREE_ATTN = enum.auto()
XFORMERS_VLLM_V1 = enum.auto() ROCM_ATTN = enum.auto()
ROCM_ATTN_VLLM_V1 = enum.auto()
class PlatformEnum(enum.Enum): class PlatformEnum(enum.Enum):
......
...@@ -218,8 +218,7 @@ class RocmPlatform(Platform): ...@@ -218,8 +218,7 @@ class RocmPlatform(Platform):
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.") f"does not support block size {block_size}.")
if selected_backend in (_Backend.ROCM_AITER_MLA, if selected_backend == _Backend.ROCM_AITER_MLA:
_Backend.ROCM_AITER_MLA_VLLM_V1):
if block_size == 1: if block_size == 1:
logger.info("Using AITER MLA backend on V1 engine.") logger.info("Using AITER MLA backend on V1 engine.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
...@@ -240,7 +239,7 @@ class RocmPlatform(Platform): ...@@ -240,7 +239,7 @@ class RocmPlatform(Platform):
elif (envs.VLLM_ROCM_USE_AITER and elif (envs.VLLM_ROCM_USE_AITER and
envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \ envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \ envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \
selected_backend == _Backend.ROCM_ATTN_VLLM_V1: selected_backend == _Backend.ROCM_ATTN:
# rocm specific backend, with aiter and/or # rocm specific backend, with aiter and/or
# triton prefix-prefill # triton prefix-prefill
logger.info("Using Rocm/Aiter Attention backend on V1 engine.") logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
......
...@@ -50,8 +50,7 @@ class TpuPlatform(Platform): ...@@ -50,8 +50,7 @@ class TpuPlatform(Platform):
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool, block_size: int, use_v1: bool, use_mla: bool,
has_sink) -> str: has_sink) -> str:
if (selected_backend != _Backend.PALLAS if selected_backend != _Backend.PALLAS:
and selected_backend != _Backend.PALLAS_VLLM_V1):
logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Cannot use %s backend on TPU.", selected_backend)
if not use_v1: if not use_v1:
......
...@@ -40,14 +40,14 @@ class XPUPlatform(Platform): ...@@ -40,14 +40,14 @@ class XPUPlatform(Platform):
use_v1 = envs.VLLM_USE_V1 use_v1 = envs.VLLM_USE_V1
if not use_v1: if not use_v1:
raise ValueError("XPU backend only supports V1.") raise ValueError("XPU backend only supports V1.")
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1: if selected_backend == _Backend.TRITON_ATTN:
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1 return TRITON_ATTN
elif selected_backend == _Backend.FLASH_ATTN: elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.") logger.info_once("Using Flash Attention backend on V1 engine.")
return FLASH_ATTN_V1 return FLASH_ATTN
elif selected_backend: elif selected_backend:
raise ValueError( raise ValueError(
f"Invalid attention backend for {cls.device_name}, " f"Invalid attention backend for {cls.device_name}, "
...@@ -64,7 +64,7 @@ class XPUPlatform(Platform): ...@@ -64,7 +64,7 @@ class XPUPlatform(Platform):
XPU only support fp8 kv cache with triton backend. XPU only support fp8 kv cache with triton backend.
""" """
if envs.is_set("VLLM_ATTENTION_BACKEND") and \ if envs.is_set("VLLM_ATTENTION_BACKEND") and \
envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1": envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN":
return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"] return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
return False return False
......
...@@ -54,7 +54,7 @@ class TorchSDPABackend(AttentionBackend): ...@@ -54,7 +54,7 @@ class TorchSDPABackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TORCH_SDPA_VLLM_V1" return "TORCH_SDPA"
@staticmethod @staticmethod
def get_impl_cls() -> type["TorchSDPABackendImpl"]: def get_impl_cls() -> type["TorchSDPABackendImpl"]:
......
...@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN_VLLM_V1" return "FLASH_ATTN"
@staticmethod @staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]: def get_impl_cls() -> type["FlashAttentionImpl"]:
......
...@@ -167,7 +167,7 @@ class FlashInferBackend(AttentionBackend): ...@@ -167,7 +167,7 @@ class FlashInferBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASHINFER_VLLM_V1" return "FLASHINFER"
@staticmethod @staticmethod
def get_impl_cls() -> type[FlashInferImpl]: def get_impl_cls() -> type[FlashInferImpl]:
......
...@@ -270,7 +270,7 @@ class MLACommonBackend(AttentionBackend): ...@@ -270,7 +270,7 @@ class MLACommonBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TRITON_MLA_VLLM_V1" return "TRITON_MLA"
@staticmethod @staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]: def get_metadata_cls() -> type["AttentionMetadata"]:
......
...@@ -27,7 +27,7 @@ class FlashMLABackend(MLACommonBackend): ...@@ -27,7 +27,7 @@ class FlashMLABackend(MLACommonBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASHMLA_VLLM_V1" return "FLASHMLA"
@staticmethod @staticmethod
def get_metadata_cls() -> type["FlashMLAMetadata"]: def get_metadata_cls() -> type["FlashMLAMetadata"]:
......
...@@ -33,7 +33,7 @@ class AiterMLABackend(MLACommonBackend): ...@@ -33,7 +33,7 @@ class AiterMLABackend(MLACommonBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "ROCM_AITER_MLA_VLLM_V1" return "ROCM_AITER_MLA"
@staticmethod @staticmethod
def get_impl_cls() -> type["AiterMLAImpl"]: def get_impl_cls() -> type["AiterMLAImpl"]:
......
...@@ -24,7 +24,7 @@ class TritonMLABackend(MLACommonBackend): ...@@ -24,7 +24,7 @@ class TritonMLABackend(MLACommonBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TRITON_MLA_VLLM_V1" return "TRITON_MLA"
@staticmethod @staticmethod
def get_impl_cls() -> type["TritonMLAImpl"]: def get_impl_cls() -> type["TritonMLAImpl"]:
......
...@@ -86,7 +86,7 @@ class PallasAttentionBackend(AttentionBackend): ...@@ -86,7 +86,7 @@ class PallasAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "PALLAS_VLLM_V1" return "PALLAS"
@staticmethod @staticmethod
def get_impl_cls() -> type["PallasAttentionBackendImpl"]: def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
......
...@@ -340,7 +340,7 @@ class AiterFlashAttentionBackend(AttentionBackend): ...@@ -340,7 +340,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN_VLLM_V1" return "FLASH_ATTN"
@staticmethod @staticmethod
def get_impl_cls() -> type["AiterFlashAttentionImpl"]: def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
......
...@@ -159,7 +159,7 @@ class RocmAttentionBackend(AttentionBackend): ...@@ -159,7 +159,7 @@ class RocmAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "ROCM_ATTN_VLLM_V1" return "ROCM_ATTN"
@staticmethod @staticmethod
def get_impl_cls() -> type["RocmAttentionImpl"]: def get_impl_cls() -> type["RocmAttentionImpl"]:
......
...@@ -52,7 +52,7 @@ class TreeAttentionBackend(AttentionBackend): ...@@ -52,7 +52,7 @@ class TreeAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TREE_ATTN_VLLM_V1" return "TREE_ATTN"
@staticmethod @staticmethod
def get_impl_cls() -> type["TreeAttentionImpl"]: def get_impl_cls() -> type["TreeAttentionImpl"]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment