Commit 64fc5a29 authored by zhuwenwen's avatar zhuwenwen
Browse files

update VLLM_FLASH_ATTN_BACKEND to VLLM_FLASH_ATTN_V1

parent e036115e
......@@ -143,7 +143,7 @@ if TYPE_CHECKING:
VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_FLASH_ATTN_BACKEND: bool = False
VLLM_FLASH_ATTN_V1: bool = False
VLLM_USE_NN: bool = False
VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS: int = 0
......@@ -962,9 +962,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
# If set, vLLM will use FlashAttention Backend for attention computation on rocm
"VLLM_FLASH_ATTN_BACKEND":
lambda: (os.environ.get("VLLM_FLASH_ATTN_BACKEND", "False").lower() in
# If set, vLLM will use FlashAttention Backend for v1 attention computation on rocm
"VLLM_FLASH_ATTN_V1":
lambda: (os.environ.get("VLLM_FLASH_ATTN_V1", "False").lower() in
("true", "1")),
# If set, vLLM will transpose weight to use nn layout
......
......@@ -239,114 +239,16 @@ class RocmPlatform(Platform):
# logger.info("Using AITER MLA backend")
# return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501
if envs.VLLM_FLASH_ATTN_BACKEND:
if use_v1:
if selected_backend == _Backend.FLASHINFER:
raise ValueError("FlashInfer backend on V1 engine is not supported")
# if selected_backend == _Backend.FLEX_ATTENTION:
# logger.info("Using FlexAttenion backend on V1 engine.")
# return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if cls.is_device_capability(100):
# Prefer FlashInfer for V1 on Blackwell GPUs if installed
try:
import flashinfer # noqa: F401
logger.info_once(
"Using FlashInfer backend on V1 engine by default for "
"Blackwell (SM 10.0) GPUs.")
return ("vllm.v1.attention.backends."
"flashinfer.FlashInferBackend")
except ImportError:
logger.info_once(
"FlashInfer failed to import for V1 engine on "
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance.")
pass
if cls.has_device_capability(80):
logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
if selected_backend == _Backend.FLASHINFER:
raise ValueError("FlashInfer backend is not supported")
elif selected_backend == _Backend.XFORMERS:
raise ValueError("XFormers backend is not supported")
# elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
# logger.info("Using DualChunkFlashAttention backend.")
# return ("vllm.attention.backends.dual_chunk_flash_attn."
# "DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
f"with use_v1: {use_v1} use_mla: {use_mla}")
target_backend = _Backend.FLASH_ATTN
if not cls.has_device_capability(80):
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
raise ValueError("XFormers backend is not supported")
elif dtype not in (torch.float16, torch.bfloat16):
logger.info(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
raise ValueError("XFormers backend is not supported")
# pass
elif block_size % 16 != 0:
logger.info(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
raise ValueError("XFormers backend is not supported")
# FlashAttn is valid for the model, checking if the package is
# installed.
if target_backend == _Backend.FLASH_ATTN:
try:
import flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend, flash_attn_supports_fp8)
supported_sizes = \
FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_sizes:
logger.info(
"Cannot use FlashAttention-2 backend for head size %d.",
head_size)
raise ValueError("XFormers backend is not supported")
fp8_kv_cache = (kv_cache_dtype is not None
and kv_cache_dtype.startswith("fp8"))
if (fp8_kv_cache and not flash_attn_supports_fp8()):
logger.info(
"Cannot use FlashAttention backend for FP8 KV cache.")
logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER")
raise ValueError("XFormers backend is not supported")
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"vllm.vllm_flash_attn package is not found. "
"Make sure that vllm_flash_attn was built and installed "
"(on by default).")
raise ValueError("XFormers backend is not supported")
if target_backend == _Backend.XFORMERS:
raise ValueError("XFormers backend is not supported")
logger.info("Using Flash Attention backend.")
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
else:
if selected_backend is None or selected_backend == _Backend.FLASH_ATTN:
selected_backend = _Backend.ROCM_FLASH
if envs.VLLM_USE_V1:
if envs.VLLM_FLASH_ATTN_V1 and block_size == 64:
if cls.has_device_capability(80):
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
else:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment