Unverified Commit bc6e542d authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Remove V0 attention backends (#25351)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent af7dfb0d
...@@ -44,7 +44,7 @@ class model_aware_kv_ops_helper: ...@@ -44,7 +44,7 @@ class model_aware_kv_ops_helper:
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size, # to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py. # For more details, see vllm/v1/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt: if self.is_deepseek_mla and self.use_mla_opt:
head_size = model_config.kv_lora_rank + \ head_size = model_config.kv_lora_rank + \
model_config.qk_rope_head_dim model_config.qk_rope_head_dim
......
...@@ -44,8 +44,8 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 ...@@ -44,8 +44,8 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.config import (get_model_path, is_interleaved, from vllm.transformers_utils.config import (get_model_path, is_interleaved,
maybe_override_with_speculators) maybe_override_with_speculators)
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip,
GiB_bytes, get_ip, is_in_ray_actor) is_in_ray_actor)
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
# yapf: enable # yapf: enable
...@@ -1163,17 +1163,6 @@ class EngineArgs: ...@@ -1163,17 +1163,6 @@ class EngineArgs:
self._set_default_args_v0(model_config) self._set_default_args_v0(model_config)
assert self.enable_chunked_prefill is not None assert self.enable_chunked_prefill is not None
if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
assert self.enforce_eager, (
"Cuda graph is not supported with DualChunkFlashAttention. "
"To run the model in eager mode, set 'enforce_eager=True' "
"or use '--enforce-eager' in the CLI.")
assert current_platform.is_cuda(), (
"DualChunkFlashAttention is only supported on CUDA platform.")
assert not use_v1, (
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
sliding_window: Optional[int] = None sliding_window: Optional[int] = None
if not is_interleaved(model_config.hf_text_config): if not is_interleaved(model_config.hf_text_config):
# Only set CacheConfig.sliding_window if the model is all sliding # Only set CacheConfig.sliding_window if the model is all sliding
......
...@@ -529,7 +529,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -529,7 +529,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "TORCH_SDPA": use torch.nn.MultiheadAttention # - "TORCH_SDPA": use torch.nn.MultiheadAttention
# - "FLASH_ATTN": use FlashAttention # - "FLASH_ATTN": use FlashAttention
# - "XFORMERS": use XFormers # - "XFORMERS": use XFormers
# - "ROCM_FLASH": use ROCmFlashAttention
# - "FLASHINFER": use flashinfer # - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA # - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA # - "FLASH_ATTN_MLA": use FlashAttention for MLA
......
...@@ -53,13 +53,18 @@ class Mamba2Metadata: ...@@ -53,13 +53,18 @@ class Mamba2Metadata:
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
"""Returns the appropriate metadata classes for the current platform.""" """Returns the appropriate metadata classes for the current platform."""
if current_platform.is_rocm(): if current_platform.is_rocm():
from vllm.attention.backends.rocm_flash_attn import ( from vllm.v1.attention.backends.rocm_aiter_fa import (
ROCmFlashAttentionMetadata) AiterFlashAttentionMetadata)
return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata) from vllm.v1.attention.backends.triton_attn import (
elif current_platform.is_cuda(): TritonAttentionMetadata)
from vllm.attention.backends.flash_attn import FlashAttentionMetadata return (AiterFlashAttentionMetadata, TritonAttentionMetadata,
from vllm.attention.backends.xformers import XFormersMetadata PlaceholderAttentionMetadata)
return (FlashAttentionMetadata, XFormersMetadata, if current_platform.is_cuda():
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata)
from vllm.v1.attention.backends.xformers import (
XFormersAttentionMetadata)
return (FlashAttentionMetadata, XFormersAttentionMetadata,
PlaceholderAttentionMetadata) PlaceholderAttentionMetadata)
raise ValueError( raise ValueError(
f"Unsupported platform for Mamba2: {current_platform.device_type}") f"Unsupported platform for Mamba2: {current_platform.device_type}")
......
...@@ -478,7 +478,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -478,7 +478,8 @@ class DeepseekV2MLAAttention(nn.Module):
Main reference: DeepseekV2 paper, and FlashInfer Implementation Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py For more info see MLACommonImpl in:
vllm/v1/attention/backends/mla/utils.py
""" """
def __init__( def __init__(
......
...@@ -226,8 +226,10 @@ class CudaPlatformBase(Platform): ...@@ -226,8 +226,10 @@ class CudaPlatformBase(Platform):
kv_cache_dtype, block_size, use_v1, use_mla, kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str: has_sink) -> str:
if use_mla: if use_mla:
# TODO(lucas): refactor to be more concise if not use_v1:
# we should probably consider factoring out V1 here raise RuntimeError(
"MLA attention backends require the V1 engine. "
"Set VLLM_USE_V1=1 to enable them.")
from vllm.attention.ops.flashmla import is_flashmla_supported from vllm.attention.ops.flashmla import is_flashmla_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla from vllm.attention.utils.fa_utils import flash_attn_supports_mla
...@@ -246,35 +248,17 @@ class CudaPlatformBase(Platform): ...@@ -246,35 +248,17 @@ class CudaPlatformBase(Platform):
use_triton = selected_backend == _Backend.TRITON_MLA or ( use_triton = selected_backend == _Backend.TRITON_MLA or (
selected_backend is None) selected_backend is None)
def _get_version(name, import_suffix) -> str:
if use_v1:
logger.info_once(f"Using {name} backend on V1 engine.")
return f"vllm.v1.attention.backends.mla.{import_suffix}"
else:
logger.info_once(f"Using {name} backend.")
return f"vllm.attention.backends.{import_suffix}"
if use_cutlassmla: if use_cutlassmla:
if use_v1: logger.info_once("Using Cutlass MLA backend on V1 engine.")
logger.info_once("Using Cutlass MLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla."
return ("vllm.v1.attention.backends.mla." "cutlass_mla.CutlassMLABackend")
"cutlass_mla.CutlassMLABackend")
else:
logger.warning(
"Cutlass MLA backend is only supported on V1 engine")
if use_flashinfermla: if use_flashinfermla:
if use_v1: from vllm.v1.attention.backends.utils import (
from vllm.v1.attention.backends.utils import ( set_kv_cache_layout)
set_kv_cache_layout) set_kv_cache_layout("HND")
set_kv_cache_layout("HND") logger.info_once("Using FlashInfer MLA backend on V1 engine.")
logger.info_once( return ("vllm.v1.attention.backends.mla."
"Using FlashInfer MLA backend on V1 engine.") "flashinfer_mla.FlashInferMLABackend")
return ("vllm.v1.attention.backends.mla."
"flashinfer_mla.FlashInferMLABackend")
else:
logger.warning(
"FlashInfer MLA backend is only supported on V1 engine"
)
if use_flashmla: if use_flashmla:
if block_size != 64: if block_size != 64:
logger.warning( logger.warning(
...@@ -282,20 +266,18 @@ class CudaPlatformBase(Platform): ...@@ -282,20 +266,18 @@ class CudaPlatformBase(Platform):
" (currently only supports block size 64).", " (currently only supports block size 64).",
block_size) block_size)
else: else:
return _get_version("FlashMLA", "flashmla.FlashMLABackend") logger.info_once("Using FlashMLA backend on V1 engine.")
if use_flashattn:
if use_v1:
logger.info_once(
"Using FlashAttention MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla." return ("vllm.v1.attention.backends.mla."
"flashattn_mla.FlashAttnMLABackend") "flashmla.FlashMLABackend")
else: if use_flashattn:
logger.warning( logger.info_once(
"FlashAttention MLA backend is only supported on V1 " "Using FlashAttention MLA backend on V1 engine.")
"engine.") return ("vllm.v1.attention.backends.mla."
"flashattn_mla.FlashAttnMLABackend")
if use_triton: if use_triton:
return _get_version("Triton MLA", logger.info_once("Using Triton MLA backend on V1 engine.")
"triton_mla.TritonMLABackend") return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
if use_v1: if use_v1:
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
...@@ -382,78 +364,9 @@ class CudaPlatformBase(Platform): ...@@ -382,78 +364,9 @@ class CudaPlatformBase(Platform):
) )
return FLEX_ATTENTION_V1 return FLEX_ATTENTION_V1
# Backends for V0 engine raise RuntimeError(
if selected_backend == _Backend.XFORMERS: "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
logger.info("Using XFormers backend.") "to select a supported backend.")
return "vllm.attention.backends.xformers.XFormersBackend"
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
logger.info("Using DifferentialFlashAttention backend.")
return ("vllm.attention.backends.differential_flash_attn."
"DifferentialFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
f"with use_v1: {use_v1} use_mla: {use_mla}")
target_backend = _Backend.FLASH_ATTN
if not cls.has_device_capability(80):
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
target_backend = _Backend.XFORMERS
elif dtype not in (torch.float16, torch.bfloat16):
logger.info(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
target_backend = _Backend.XFORMERS
elif block_size % 16 != 0:
logger.info(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
target_backend = _Backend.XFORMERS
# FlashAttn is valid for the model, checking if the package is
# installed.
if target_backend == _Backend.FLASH_ATTN:
try:
import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend, flash_attn_supports_fp8)
supported_sizes = \
FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_sizes:
logger.info(
"Cannot use FlashAttention-2 backend for head size %d.",
head_size)
target_backend = _Backend.XFORMERS
fp8_kv_cache = (kv_cache_dtype is not None
and kv_cache_dtype.startswith("fp8"))
if (fp8_kv_cache and not flash_attn_supports_fp8()):
logger.info(
"Cannot use FlashAttention backend for FP8 KV cache.")
target_backend = _Backend.XFORMERS
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"vllm.vllm_flash_attn package is not found. "
"Make sure that vllm_flash_attn was built and installed "
"(on by default).")
target_backend = _Backend.XFORMERS
if target_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend"
logger.info("Using Flash Attention backend.")
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
@classmethod @classmethod
def get_punica_wrapper(cls) -> str: def get_punica_wrapper(cls) -> str:
......
...@@ -191,6 +191,11 @@ class RocmPlatform(Platform): ...@@ -191,6 +191,11 @@ class RocmPlatform(Platform):
kv_cache_dtype, block_size, use_v1, use_mla, kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str: has_sink) -> str:
if use_mla: if use_mla:
if not use_v1:
raise RuntimeError(
"MLA attention backends require the V1 engine. "
"Set VLLM_USE_V1=1 to enable them.")
from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
is_aiter_mla_enabled) is_aiter_mla_enabled)
...@@ -201,39 +206,24 @@ class RocmPlatform(Platform): ...@@ -201,39 +206,24 @@ class RocmPlatform(Platform):
if selected_backend == _Backend.TRITON_MLA: if selected_backend == _Backend.TRITON_MLA:
if block_size != 1: if block_size != 1:
if use_v1: logger.info_once("Using Triton MLA backend on V1 engine.")
logger.info_once( return ("vllm.v1.attention.backends.mla."
"Using Triton MLA backend on V1 engine.") "triton_mla.TritonMLABackend")
return ("vllm.v1.attention.backends.mla." raise ValueError(
"triton_mla.TritonMLABackend") f" The selected backend, {selected_backend.name},"
else: f"does not support block size {block_size}.")
logger.info("Using Triton MLA backend.") if selected_backend in (_Backend.ROCM_AITER_MLA,
return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 _Backend.ROCM_AITER_MLA_VLLM_V1):
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.")
elif selected_backend == _Backend.ROCM_AITER_MLA \
or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1:
if block_size == 1: if block_size == 1:
if use_v1: logger.info("Using AITER MLA backend on V1 engine.")
logger.info("Using AITER MLA backend on V1 engine.") return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
else:
logger.info("Using AITER MLA backend")
return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
"(currently only supports block size 1)")
else:
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.") f"does not support block size {block_size}."
"(currently only supports block size 1)")
if selected_backend is None or selected_backend == _Backend.FLASH_ATTN: raise ValueError(
selected_backend = _Backend.ROCM_FLASH f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.")
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \ if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
...@@ -245,14 +235,9 @@ class RocmPlatform(Platform): ...@@ -245,14 +235,9 @@ class RocmPlatform(Platform):
logger.info("Using Triton Attention backend on V1 engine.") logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends." return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend") "triton_attn.TritonAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH: raise RuntimeError(
if not cls.has_device_capability(90): "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
# not Instinct series GPUs. "to select a supported backend.")
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
logger.info("Using ROCmFlashAttention backend.")
return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:
......
...@@ -157,10 +157,8 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" ...@@ -157,10 +157,8 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# register, corresponding to possible backends # register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID" STR_INVALID_VAL: str = "INVALID"
MB_bytes = 1_000_000 MB_bytes = 1_000_000
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment