Commit eb35ba1b authored by fanwl's avatar fanwl
Browse files

Add FA Unified Attention 2D

- Add the VLLM_V1_USE_FA_UNIFIED_ATTN_2D environment variable
- 0 (default): use the Triton attention kernel; 1: use the FA (flash attention) unified attention kernel
parent 3f414133
......@@ -310,6 +310,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False
def get_default_cache_root():
......@@ -1940,6 +1941,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK", "False").lower() in
("true", "1")),
# If set to 1/True, the 2D unified attention path calls the flash attention
# unified kernel (varlen_fwd_unified) instead of the Triton kernel.
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D":
lambda: (os.environ.get("VLLM_V1_USE_FA_UNIFIED_ATTN_2D", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -426,7 +426,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
dtype=kv_cache_dtype,
).page_size_bytes
else:
kernel_block_alignment_size = 16
kernel_block_alignment_size = 64
if (
current_platform.is_device_capability_family(100)
and model_config.get_head_size() == 256
......
......@@ -12,6 +12,10 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm import envs
from flash_attn import (
varlen_fwd_unified,
)
logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
......@@ -982,62 +986,89 @@ def unified_attention(
or max_seqlen_q > 1
or num_seqs > seq_threshold_3D
):
kernel_unified_attention_2d[
(
total_num_q_blocks,
num_kv_heads,
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
if not envs.VLLM_V1_USE_FA_UNIFIED_ATTN_2D:
# print("Running Triton kernel")
kernel_unified_attention_2d[
(
total_num_q_blocks,
num_kv_heads,
)
](
output_ptr=out,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_ALIBI_SQRT=use_alibi_sqrt,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
# print("Running FA kernel")
varlen_fwd_unified(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
softcap=softcap,
window_size=window_size,
alibi_slopes=alibi_slopes,
use_alibi_sqrt=use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=sinks,
mm_prefix_range=mm_prefix_range,
return_softmax_lse=False,
out=out,
)
](
output_ptr=out,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_ALIBI_SQRT=use_alibi_sqrt,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
# print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_3d[
(total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
](
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment