Commit d03e4bf6 authored by laibao's avatar laibao
Browse files

feat(attn): ROCm块大小为64倍数(且不等于64)时走FA varlen_fwd_unified

parent 1ea9a3f0
......@@ -316,7 +316,6 @@ if TYPE_CHECKING:
VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING: bool = False
......@@ -1978,10 +1977,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_ENABLE_RAY_ASYNC_SCHEDULING", "False").lower() in
("true", "1")),
#If set to 1/True, enable the flash attention unified path.
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D":
lambda: (os.environ.get("VLLM_V1_USE_FA_UNIFIED_ATTN_2D", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -304,9 +304,23 @@ class RocmPlatform(Platform):
f"is not MLA type while requested for MLA backend."
)
if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
use_unified_flash = (
block_size is not None
and block_size != 64
and block_size % 64 == 0
)
if envs.VLLM_USE_FLASH_ATTN_PA and (block_size == 64 or use_unified_flash):
if use_unified_flash and block_size != 64:
logger.info_once(
"Using Flash Attention backend with unified varlen kernel on "
"V1 engine. (block size %d, requires block size divisible by 64)",
block_size,
)
else:
logger.info_once(
"Using Flash Attention backend on V1 engine. "
"(only supports block size 64)"
)
return AttentionBackendEnum.FLASH_ATTN.get_path()
else:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
......
......@@ -33,6 +33,13 @@ if is_flash_attn_varlen_func_available():
vllm_flash_attn_varlen_func,
reshape_and_cache_cuda,
)
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
else:
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks,
......@@ -112,6 +119,30 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@staticmethod
def _use_rocm_unified_kv_layout(
block_size: int | None = None,
key_cache: torch.Tensor | None = None,
value_cache: torch.Tensor | None = None,
) -> bool:
if not current_platform.is_rocm():
return False
if block_size is None:
if key_cache is not None and value_cache is not None:
if key_cache.ndim != 4 or value_cache.ndim != 4:
return False
if key_cache.shape != value_cache.shape:
return False
block_size = key_cache.shape[1]
else:
try:
block_size = get_current_vllm_config().cache_config.block_size
except Exception:
return False
return block_size is not None and block_size != 64 and block_size % 64 == 0
if current_platform.is_rocm():
@staticmethod
......@@ -124,6 +155,9 @@ class FlashAttentionBackend(AttentionBackend):
) -> tuple[tuple[int, ...], tuple[int, ...]]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
if FlashAttentionBackend._use_rocm_unified_kv_layout(block_size):
unified_shape = (num_blocks, block_size, num_kv_heads, head_size)
return (unified_shape, unified_shape)
return (
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
......@@ -136,20 +170,31 @@ class FlashAttentionBackend(AttentionBackend):
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
elif cache_layout == "NHD":
if FlashAttentionBackend._use_rocm_unified_kv_layout():
if cache_layout != "NHD":
raise RuntimeError(
"ROCm unified KV layout currently supports NHD only."
)
if include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 2, 3, 4), (1, 0, 2, 3, 4)
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3)
elif cache_layout == "HND":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
elif cache_layout == "NHD":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3)
elif cache_layout == "HND":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order
else:
@staticmethod
......@@ -774,30 +819,57 @@ class FlashAttentionImpl(AttentionImpl):
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
# num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
is_prefix_cache=True,
use_unified_kv_layout = (
FlashAttentionBackend._use_rocm_unified_kv_layout(
key_cache=key_cache, value_cache=value_cache)
)
if use_unified_kv_layout:
varlen_fwd_unified(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
softcap=self.logits_soft_cap,
window_size=tuple(self.sliding_window),
alibi_slopes=self.alibi_slopes,
use_alibi_sqrt=False,
qq_bias=None,
s_aux=self.sinks,
mm_prefix_range=None,
return_softmax_lse=False,
out=output[:num_actual_tokens],
)
else:
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
# num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
is_prefix_cache=True,
)
else:
flash_attn_varlen_func(
q=query[:num_actual_tokens],
......@@ -889,21 +961,11 @@ class FlashAttentionImpl(AttentionImpl):
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if current_platform.is_rocm():
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale
)
else:
from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
if FlashAttentionBackend._use_rocm_unified_kv_layout(
key_cache=key_cache,
value_cache=value_cache,
):
triton_reshape_and_cache_flash(
key,
value,
key_cache,
......@@ -913,6 +975,31 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
else:
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale
)
else:
from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
reshape_and_cache_flash(
key,
......
......@@ -12,11 +12,6 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm import envs
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
......@@ -988,90 +983,61 @@ def unified_attention(
or num_seqs > seq_threshold_3D
):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
if not envs.VLLM_V1_USE_FA_UNIFIED_ATTN_2D:
# print("Running Triton kernel")
kernel_unified_attention_2d[
(
total_num_q_blocks,
num_kv_heads,
)
](
output_ptr=out,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_ALIBI_SQRT=use_alibi_sqrt,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
if varlen_fwd_unified is None:
raise RuntimeError(
"flash_attn.varlen_fwd_unified is not available in this flash-attn version"
)
# print("Running FA kernel")
varlen_fwd_unified(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
softcap=softcap,
window_size=window_size,
alibi_slopes=alibi_slopes,
use_alibi_sqrt=use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=sinks,
mm_prefix_range=mm_prefix_range,
return_softmax_lse=False,
out=out,
kernel_unified_attention_2d[
(
total_num_q_blocks,
num_kv_heads,
)
](
output_ptr=out,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_ALIBI_SQRT=use_alibi_sqrt,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
# print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_3d[
......
......@@ -5951,7 +5951,7 @@ class GPUModelRunner(
return kv_caches
def _update_hybrid_attention_mamba_layout(
self, kv_caches: dict[str, torch.Tensor]
self, kv_caches: dict[str, Any]
) -> None:
"""
Update the layout of attention layers from (2, num_blocks, ...) to
......@@ -5965,6 +5965,8 @@ class GPUModelRunner(
kv_cache_spec = group.kv_cache_spec
for layer_name in group.layer_names:
kv_cache = kv_caches[layer_name]
if not isinstance(kv_cache, torch.Tensor):
continue
if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2:
assert kv_cache.shape[1] != 2, (
"Fail to determine whether the layout is "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment