Commit d03e4bf6 authored by laibao's avatar laibao
Browse files

feat(attn): ROCm块大小为64倍数(且不等于64)时走FA varlen_fwd_unified

parent 1ea9a3f0
......@@ -316,7 +316,6 @@ if TYPE_CHECKING:
VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING: bool = False
......@@ -1978,10 +1977,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_ENABLE_RAY_ASYNC_SCHEDULING", "False").lower() in
("true", "1")),
#If set to 1/True, enable the flash attention unified path.
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D":
lambda: (os.environ.get("VLLM_V1_USE_FA_UNIFIED_ATTN_2D", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -304,9 +304,23 @@ class RocmPlatform(Platform):
f"is not MLA type while requested for MLA backend."
)
if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
use_unified_flash = (
block_size is not None
and block_size != 64
and block_size % 64 == 0
)
if envs.VLLM_USE_FLASH_ATTN_PA and (block_size == 64 or use_unified_flash):
if use_unified_flash and block_size != 64:
logger.info_once(
"Using Flash Attention backend with unified varlen kernel on "
"V1 engine. (block size %d, requires block size divisible by 64)",
block_size,
)
else:
logger.info_once(
"Using Flash Attention backend on V1 engine. "
"(only supports block size 64)"
)
return AttentionBackendEnum.FLASH_ATTN.get_path()
else:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
......
......@@ -33,6 +33,13 @@ if is_flash_attn_varlen_func_available():
vllm_flash_attn_varlen_func,
reshape_and_cache_cuda,
)
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
else:
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks,
......@@ -113,6 +120,30 @@ class FlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@staticmethod
def _use_rocm_unified_kv_layout(
block_size: int | None = None,
key_cache: torch.Tensor | None = None,
value_cache: torch.Tensor | None = None,
) -> bool:
if not current_platform.is_rocm():
return False
if block_size is None:
if key_cache is not None and value_cache is not None:
if key_cache.ndim != 4 or value_cache.ndim != 4:
return False
if key_cache.shape != value_cache.shape:
return False
block_size = key_cache.shape[1]
else:
try:
block_size = get_current_vllm_config().cache_config.block_size
except Exception:
return False
return block_size is not None and block_size != 64 and block_size % 64 == 0
if current_platform.is_rocm():
@staticmethod
def get_kv_cache_shape(
......@@ -124,6 +155,9 @@ class FlashAttentionBackend(AttentionBackend):
) -> tuple[tuple[int, ...], tuple[int, ...]]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
if FlashAttentionBackend._use_rocm_unified_kv_layout(block_size):
unified_shape = (num_blocks, block_size, num_kv_heads, head_size)
return (unified_shape, unified_shape)
return (
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
......@@ -136,6 +170,17 @@ class FlashAttentionBackend(AttentionBackend):
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if FlashAttentionBackend._use_rocm_unified_kv_layout():
if cache_layout != "NHD":
raise RuntimeError(
"ROCm unified KV layout currently supports NHD only."
)
if include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 2, 3, 4), (1, 0, 2, 3, 4)
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3)
else:
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
......@@ -774,6 +819,33 @@ class FlashAttentionImpl(AttentionImpl):
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
use_unified_kv_layout = (
FlashAttentionBackend._use_rocm_unified_kv_layout(
key_cache=key_cache, value_cache=value_cache)
)
if use_unified_kv_layout:
varlen_fwd_unified(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
softcap=self.logits_soft_cap,
window_size=tuple(self.sliding_window),
alibi_slopes=self.alibi_slopes,
use_alibi_sqrt=False,
qq_bias=None,
s_aux=self.sinks,
mm_prefix_range=None,
return_softmax_lse=False,
out=output[:num_actual_tokens],
)
else:
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
......@@ -889,6 +961,21 @@ class FlashAttentionImpl(AttentionImpl):
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if current_platform.is_rocm():
if FlashAttentionBackend._use_rocm_unified_kv_layout(
key_cache=key_cache,
value_cache=value_cache,
):
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
......
......@@ -12,11 +12,6 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm import envs
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
......@@ -988,8 +983,6 @@ def unified_attention(
or num_seqs > seq_threshold_3D
):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
if not envs.VLLM_V1_USE_FA_UNIFIED_ATTN_2D:
# print("Running Triton kernel")
kernel_unified_attention_2d[
(
total_num_q_blocks,
......@@ -1045,33 +1038,6 @@ def unified_attention(
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
if varlen_fwd_unified is None:
raise RuntimeError(
"flash_attn.varlen_fwd_unified is not available in this flash-attn version"
)
# print("Running FA kernel")
varlen_fwd_unified(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
softcap=softcap,
window_size=window_size,
alibi_slopes=alibi_slopes,
use_alibi_sqrt=use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=sinks,
mm_prefix_range=mm_prefix_range,
return_softmax_lse=False,
out=out,
)
else:
# print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_3d[
......
......@@ -5951,7 +5951,7 @@ class GPUModelRunner(
return kv_caches
def _update_hybrid_attention_mamba_layout(
self, kv_caches: dict[str, torch.Tensor]
self, kv_caches: dict[str, Any]
) -> None:
"""
Update the layout of attention layers from (2, num_blocks, ...) to
......@@ -5965,6 +5965,8 @@ class GPUModelRunner(
kv_cache_spec = group.kv_cache_spec
for layer_name in group.layer_names:
kv_cache = kv_caches[layer_name]
if not isinstance(kv_cache, torch.Tensor):
continue
if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2:
assert kv_cache.shape[1] != 2, (
"Fail to determine whether the layout is "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment