Commit 82f60bef authored by zhuwenwen
Browse files

Set VLLM_FLASH_ATTN_BACKEND to use the FlashAttention backend for attention computation on ROCm

parent 07b41ddf
......@@ -27,7 +27,12 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
from vllm.platforms import current_platform
if not current_platform.is_rocm():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
else:
from flash_attn import (flash_attn_varlen_func, vllm_flash_attn_varlen_func,
flash_attn_with_kvcache)
if TYPE_CHECKING:
......@@ -807,6 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
(num_kv_tokens, num_kv_heads, head_size))
descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1])
if not current_platform.is_rocm():
flash_attn_varlen_func(
q=query,
k=key,
......@@ -826,6 +832,21 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
else:
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=q_seq_len,
max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale,
causal=_get_causal_option(attn_type),
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
......@@ -835,6 +856,7 @@ class FlashAttentionImpl(AttentionImpl):
max_seq_len = max(prefill_meta.seq_lens)
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
key.shape[1])
if not current_platform.is_rocm():
flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
......@@ -855,6 +877,27 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
else:
vllm_flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens_tensor,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
......@@ -870,6 +913,7 @@ class FlashAttentionImpl(AttentionImpl):
assert decode_meta.query_start_loc is not None
descale_shape = (decode_meta.query_start_loc.shape[0] - 1,
key.shape[1])
if not current_platform.is_rocm():
flash_attn_varlen_func(
q=decode_query,
k=key_cache,
......@@ -890,6 +934,22 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
else:
decode_output = flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
seqused_k=decode_meta.seq_lens_tensor,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
(
......@@ -898,6 +958,7 @@ class FlashAttentionImpl(AttentionImpl):
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2])
if not current_platform.is_rocm():
flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
......@@ -915,6 +976,20 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
else:
decode_output = decode_output.unsqueeze(1)
decode_output = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=block_tables_arg,
cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
return output
......
......@@ -124,7 +124,7 @@ if TYPE_CHECKING:
VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_FLASH_ATTN_BACKEND: bool = False
VLLM_ENABLE_TBO: bool = False
VLLM_ZERO_OVERHEAD: bool = False
......@@ -799,14 +799,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENFORCE_EAGER_BS_THRESHOLD":
lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")),
# Enable two batch overlap.
"VLLM_ENABLE_TBO":
lambda: bool(int(os.getenv("VLLM_ENABLE_TBO", "0"))),
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
# 'has_context' is a variable in common.py that is computed from the
# metadata by default. However, computing it may introduce a synchronization
# and hurt performance, so it is directly assigned False.
......@@ -814,6 +806,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
# to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
# If set, vLLM uses the FlashAttention backend for attention computation on ROCm.
"VLLM_FLASH_ATTN_BACKEND":
lambda: (os.environ.get("VLLM_FLASH_ATTN_BACKEND", "False").lower() in
("true", "1")),
# Enable two batch overlap.
"VLLM_ENABLE_TBO":
lambda: bool(int(os.getenv("VLLM_ENABLE_TBO", "0"))),
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
}
# end-env-vars-definition
......
......@@ -205,6 +205,87 @@ class RocmPlatform(Platform):
# f" The selected backend, {selected_backend.name},"
# f"is not MLA type while requested for MLA backend.")
if envs.VLLM_FLASH_ATTN_BACKEND:
if use_v1:
if selected_backend == _Backend.FLASHINFER:
raise ValueError("FlashInfer backend on V1 engine is not supported")
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if cls.has_device_capability(80):
logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
if selected_backend == _Backend.FLASHINFER:
raise ValueError("FlashInfer backend is not supported")
elif selected_backend == _Backend.XFORMERS:
raise ValueError("XFormers backend is not supported")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
f"with use_v1: {use_v1} use_mla: {use_mla}")
target_backend = _Backend.FLASH_ATTN
if not cls.has_device_capability(80):
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
raise ValueError("XFormers backend is not supported")
elif dtype not in (torch.float16, torch.bfloat16):
logger.info(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
raise ValueError("XFormers backend is not supported")
elif block_size % 16 != 0:
logger.info(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
raise ValueError("XFormers backend is not supported")
# FlashAttn is valid for the model, checking if the package is
# installed.
if target_backend == _Backend.FLASH_ATTN:
try:
import flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend, flash_attn_supports_fp8)
supported_sizes = \
FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_sizes:
logger.info(
"Cannot use FlashAttention-2 backend for head size %d.",
head_size)
raise ValueError("XFormers backend is not supported")
fp8_kv_cache = (kv_cache_dtype is not None
and kv_cache_dtype.startswith("fp8"))
if (fp8_kv_cache and not flash_attn_supports_fp8()):
logger.info(
"Cannot use FlashAttention backend for FP8 KV cache.")
logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER")
raise ValueError("XFormers backend is not supported")
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"flash_attn package is not found. "
"Make sure that flash_attn was built and installed "
"(on by default).")
raise ValueError("XFormers backend is not supported")
if target_backend == _Backend.XFORMERS:
raise ValueError("XFormers backend is not supported")
logger.info("Using Flash Attention backend.")
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
else:
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1:
......
......@@ -24,11 +24,11 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if current_platform.is_cuda():
if not current_platform.is_rocm():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata)
else:
from flash_attn import flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
logger = init_logger(__name__)
......@@ -605,6 +605,7 @@ class FlashAttentionImpl(AttentionImpl):
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
if not current_platform.is_rocm():
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
......@@ -626,11 +627,30 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
else:
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
# scheduler_metadata=scheduler_metadata,
)
return output
assert not use_local_attn, (
"Cascade attention does not support local attention.")
# Cascade attention (rare case).
if not current_platform.is_rocm():
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
......@@ -656,6 +676,8 @@ class FlashAttentionImpl(AttentionImpl):
v_descale=layer._v_scale,
)
return output
else:
raise ValueError("cascade attention is not supported on rocm")
def use_cascade_attention(
......@@ -763,6 +785,7 @@ def cascade_attention(
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process shared prefix.
if not current_platform.is_rocm():
prefix_output, prefix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
......@@ -790,6 +813,7 @@ def cascade_attention(
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process suffix per query.
if not current_platform.is_rocm():
suffix_output, suffix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment