Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
// Copyright (c) 2025, Xin Zhou.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
// Explicit specialization of run_fp8_mha_fwd_prefix_prefill_ for <BFloat16, 256, 256>.
// When BUILD_FA_FWD is defined it forwards `params` and `stream` unchanged to
// run_fp8_flash_fwd_prefix_prefill; otherwise the body is empty and the call does nothing.
template<>
void run_fp8_mha_fwd_prefix_prefill_<BFloat16, 256, 256>(Flash_fwd_params &params, hipStream_t stream) {
#if defined(BUILD_FA_FWD)
    run_fp8_flash_fwd_prefix_prefill<BFloat16, 256, 256>(params, stream);
#endif  // BUILD_FA_FWD
}
// Copyright (c) 2025, Xin Zhou.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
// Explicit specialization of run_fp8_mha_fwd_prefix_prefill_ for <Float16, 256, 256>.
// When BUILD_FA_FWD is defined it forwards `params` and `stream` unchanged to
// run_fp8_flash_fwd_prefix_prefill; otherwise the body is empty and the call does nothing.
template<>
void run_fp8_mha_fwd_prefix_prefill_<Float16, 256, 256>(Flash_fwd_params &params, hipStream_t stream) {
#if defined(BUILD_FA_FWD)
    run_fp8_flash_fwd_prefix_prefill<Float16, 256, 256>(params, stream);
#endif  // BUILD_FA_FWD
}
// Copyright (c) 2025, Wenjian Zhang.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
// Explicit specialization of run_fp8_mha_fwd_prefix_prefill_ for <BFloat16, 192, 128>.
// When BUILD_FA_FWD is defined it forwards `params` and `stream` unchanged to
// run_fp8_flash_fwd_prefix_prefill; otherwise the body is empty and the call does nothing.
template<>
void run_fp8_mha_fwd_prefix_prefill_<BFloat16, 192, 128>(Flash_fwd_params &params, hipStream_t stream) {
#if defined(BUILD_FA_FWD)
    run_fp8_flash_fwd_prefix_prefill<BFloat16, 192, 128>(params, stream);
#endif  // BUILD_FA_FWD
}
// Copyright (c) 2025, Wenjian Zhang.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
// Explicit specialization of run_fp8_mha_fwd_prefix_prefill_ for <Float16, 192, 128>.
// When BUILD_FA_FWD is defined it forwards `params` and `stream` unchanged to
// run_fp8_flash_fwd_prefix_prefill; otherwise the body is empty and the call does nothing.
template<>
void run_fp8_mha_fwd_prefix_prefill_<Float16, 192, 128>(Flash_fwd_params &params, hipStream_t stream) {
#if defined(BUILD_FA_FWD)
    run_fp8_flash_fwd_prefix_prefill<Float16, 192, 128>(params, stream);
#endif  // BUILD_FA_FWD
}
#ifdef BUILD_FA_PERMUTE
#include <hip/hip_runtime.h>
#include "../../include/intrinsic.h"
#include "../flash_fwd_permute_hdim128.h"
template<>
......@@ -140,28 +139,22 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<128, 4, 32>(
int32_t block_offset = seqlen_limit * kHeadDim;
int32_t thread_offset = lane_id_col * 8;
// 一次读取 4x128 的 Half 到 LDS
#if defined(__gfx936__) || defined(__gfx938__)
{
auto *lds_ptr = (__attribute__((address_space(3))) int *)(
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float));
__builtin_hcu_raw_buffer_load_lds(
read_buffer,
lds_ptr,
16,
(block_offset + thread_offset) << 1, /* v_offset */
0, /* s_offset */
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
__builtin_hcu_raw_buffer_load_lds(
read_buffer,
lds + lane_id * 4,
16,
(block_offset + thread_offset) << 1, /* v_offset */
0, /* s_offset */
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
#else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif
// 从 LDS 转置后, 64 个线程写 4 行, 每次写 128 个 Half, 对应 fetch * 4 + [0,3] 的 seqlen
vec2_fp32 data0, data1;
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 0) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[0];
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 1) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[1];
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 2) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data1[0];
......@@ -370,4 +363,4 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<256, 1, 32>(
#endif
#endif
\ No newline at end of file
#ifdef BUILD_FA_PERMUTE
#include <hip/hip_runtime.h>
#include "../../include/intrinsic.h"
#include "../flash_fwd_permute_hdim128.h"
......@@ -118,29 +117,23 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 0>(
// 接下来, 这个 block 要读取 4x128 的内容, 15 个线程读取一行 128 个 half(这里写死了 head_dim = 128), 每个线程读取 8 个 half
int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__)
{
auto *lds_ptr = (__attribute__((address_space(3))) int *)(
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float));
__builtin_hcu_raw_buffer_load_lds(
read_buffer,
lds_ptr,
16,
(block_offset + thread_offset) << 1, /* v_offset */
0, /* s_offset */
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
__builtin_hcu_raw_buffer_load_lds(
read_buffer,
lds + lane_id * 4,
16,
(block_offset + thread_offset) << 1, /* v_offset */
0, /* s_offset */
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
#else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif
// 写到 lds 不需要同步, 因为只有一个 wave
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
vec2_fp32 data0, data1;
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
write_ptr[(fetch * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0];
write_ptr[(fetch * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1];
write_ptr[(fetch * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0];
......@@ -267,29 +260,23 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 32>(
// 接下来, 这个 block 要读取 4x128 的内容, 15 个线程读取一行 128 个 half(这里写死了 head_dim = 128), 每个线程读取 8 个 half
int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__)
{
auto *lds_ptr = (__attribute__((address_space(3))) int *)(
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float));
__builtin_hcu_raw_buffer_load_lds(
read_buffer,
lds_ptr,
16,
(block_offset + thread_offset) << 1, /* v_offset */
0, /* s_offset */
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
__builtin_hcu_raw_buffer_load_lds(
read_buffer,
lds + lane_id * 4,
16,
(block_offset + thread_offset) << 1, /* v_offset */
0, /* s_offset */
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
#else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif
// 写到 lds 不需要同步, 因为只有一个 wave
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
vec2_fp32 data0, data1;
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0];
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1];
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0];
......@@ -352,7 +339,7 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
// 接下来, 这个 block 要读取 4x128 的内容, 15 个线程读取一行 128 个 half(这里写死了 head_dim = 128), 每个线程读取 8 个 half
int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
int m0_offset = reinterpret_cast<size_t>(lds) + (fetch * 256 << 2);
int offset_v = (block_offset + thread_offset) << 1;
asm volatile(
......@@ -377,15 +364,14 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
// 把所有的 buffer_load 指令下发之后, 再从 lds 开始读取
#pragma unroll
for (int32_t fetch = 0; fetch < SEQLEN_PER_BLOCK; ++fetch) {
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(%0)\n" :: "B"(SEQLEN_PER_BLOCK - fetch - 1));
__builtin_amdgcn_sched_barrier(0);
#endif
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
inlineasm_fa_ds_read2_b32(lds, fetch * 256 + lane_id, registers_buffer[fetch * 2], 0, 64);
inlineasm_fa_ds_read2_b32(lds, fetch * 256 + lane_id + 128, registers_buffer[fetch * 2 + 1], 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
registers_buffer[fetch * 2] = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + fetch * 256 + lane_id, 0, 64, false);
registers_buffer[fetch * 2 + 1] = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + fetch * 256 + lane_id + 128, 0, 64, false);
}
__builtin_amdgcn_sched_barrier(0);
......@@ -394,7 +380,7 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
for (int32_t fetch = 0; fetch < SEQLEN_PER_BLOCK; ++fetch) {
// 限制边界
int32_t seqlen_limit = min(actual_seqlen - 1, fetch);
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
// 计算固定的偏移, 字节数目
int32_t v_addr = (seqlen_limit * head_dim << 1) + (lane_id << 2);
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
......@@ -688,4 +674,4 @@ template<>
__global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<256, 4, 32>(
void* output, void* query, void* split_sizes, int64_t head_stride, int32_t num_heads, int real_headdim) {}
#endif // end of BUILD_FA_PERMUTE
#endif // end of BUILD_FA_PERMUTE
\ No newline at end of file
......@@ -8,6 +8,8 @@ if torch.cuda.is_available():
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
hg_flash_attn_varlen_func,
flash_mla_with_kvcache,
get_mla_metadata,
vllm_flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
......
......@@ -2,7 +2,6 @@
from typing import Optional, Union
from typing import List, Tuple
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -19,13 +18,19 @@ from flash_attn.utils.sparse_utils import hyperparameter_check, get_block_map_me
DEFAULT_FA_VERSION = 2
try:
torch._C._dispatch_find_schema_or_throw("flash_attn2_c_op::varlen_fwd", "")
_has_flash_attn2_c_varlen_fwd = True
except RuntimeError:
_has_flash_attn2_c_varlen_fwd = False
def maybe_contiguous(x):
    """Return ``x`` itself unless its last-dimension stride is not 1.

    ``None`` is passed through untouched. An object whose ``stride(-1)`` is
    already 1 is returned as-is; anything else is replaced by
    ``x.contiguous()``.
    """
    if x is None:
        return x
    if x.stride(-1) == 1:
        return x
    return x.contiguous()
def round_multiple(x, m):
    """Round ``x`` up to a multiple of ``m`` (for positive integer inputs).

    Computed as a ceiling division of ``x`` by ``m`` followed by scaling the
    quotient back up by ``m``.
    """
    chunk_count = (x + m - 1) // m
    return chunk_count * m
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel
......@@ -59,7 +64,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
elif head_dim <= 512:
return 64
if torch.__version__ >= "2.4.0":
if torch.__version__ >= "2.4.0" and _has_flash_attn2_c_varlen_fwd:
_torch_custom_op_wrapper = torch.library.custom_op
_torch_register_fake_wrapper = torch.library.register_fake
else:
......@@ -199,7 +204,11 @@ def varlen_fwd_fake(
return out, softmax_lse, p, rng_state
_wrapped_flash_attn_varlen_forward = torch.ops.flash_attn2_c_op.varlen_fwd
_wrapped_flash_attn_varlen_forward = (
torch.ops.flash_attn2_c_op.varlen_fwd
if _has_flash_attn2_c_varlen_fwd
else _flash_attn_varlen_forward
)
def _flash_attn_backward(
......@@ -596,7 +605,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
softcap,
alibi_slopes,
deterministic,
return_softmax,
return_softmax,
bhsd = False
):
if softmax_scale is None:
......@@ -611,7 +620,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
return_softmax=return_softmax and dropout_p > 0,
bhsd = bhsd
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
......@@ -1922,7 +1931,7 @@ def vllm_flash_attn_varlen_func(
# Version selector
fa_version: int = DEFAULT_FA_VERSION,
s_aux=None,
):
):
"""
仅用于vllm prefix cache
dropout_p should be set to 0.0 during evaluation
......@@ -1994,7 +2003,7 @@ def vllm_flash_attn_varlen_func(
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
......@@ -2005,7 +2014,7 @@ def vllm_flash_attn_varlen_func(
bs = cu_seqlens_q.shape[0] - 1
total_q = q.shape[0]
# max_seqlen_q*bs==total_q and max_seqlen_q<=4 means mtp
# if mtp, k head must be 1.
# if mtp, k head must be 1.
# todo : support k head >1
is_mtp = (max_seqlen_q*bs==total_q and max_seqlen_q>1 and max_seqlen_q<5)
if (max_seqlen_q==1 or is_mtp ) and real_window_size[0]==-1:
......@@ -2015,9 +2024,9 @@ def vllm_flash_attn_varlen_func(
else :
out = torch.empty_like(q)
flash_attn_cuda.paged_attention(out,q.reshape(bs,max_seqlen_q,q.shape[1],q.shape[-1]),k,v,softmax_scale,block_table,
seqused_k,alibi_slopes,kv_cache_dtype,q_descale,k_descale,v_descale,max_seqlen_k,s_aux)
seqused_k,alibi_slopes,kv_cache_dtype,q_descale,k_descale,v_descale,max_seqlen_k,s_aux)
return out
is_938 = "gfx938" in torch.cuda.get_device_properties("cuda").gcnArchName
is_938 = ("gfx938" in torch.cuda.get_device_properties("cuda").gcnArchName or "gfx92a" in torch.cuda.get_device_properties("cuda").gcnArchName)
if (not is_938) and k.dtype == torch.float8_e5m2 and v.dtype == torch.float8_e5m2:
assert q.dtype != torch.float8_e5m2 , "UnSupport q.dtype:fp8"
q_descale = None
......@@ -2048,7 +2057,7 @@ def vllm_flash_attn_varlen_func(
None,
s_aux,
)
else:
else:
if(k.dtype == torch.float8_e4m3fn or k.dtype == torch.float8_e5m2) and q.dtype != k.dtype:
if q_descale is not None:
q=q/q_descale
......@@ -2059,7 +2068,7 @@ def vllm_flash_attn_varlen_func(
v,
out,
cu_seqlens_q,
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# still wants it so we pass all zeros
dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
seqused_k,
......@@ -2092,7 +2101,7 @@ def vllm_flash_attn_varlen_func(
v,
out,
cu_seqlens_q,
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# still wants it so we pass all zeros
dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
seqused_k,
......@@ -2334,6 +2343,7 @@ def flash_attn_with_kvcache(
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
s_aux = maybe_contiguous(s_aux)
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if cache_seqlens is not None and isinstance(cache_seqlens, int):
......@@ -2646,7 +2656,7 @@ def sparse_attn_varlen_func(
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
......@@ -2682,7 +2692,7 @@ def sparse_attn_varlen_func(
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = flash_attn_cuda.varlen_fwd_sparse(
q,
......@@ -2723,7 +2733,7 @@ def varlen_fwd_unified(
softmax_scale=None,
causal=False,
softcap=0.0,
window_size=(-1, -1),
window_size=(-1, -1),
alibi_slopes=None,
use_alibi_sqrt=False,
qq_bias=None,
......@@ -2732,15 +2742,125 @@ def varlen_fwd_unified(
*,
out=None,
return_softmax_lse=False,
q_descale=None,
k_descale=None,
v_descale=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
window_size_left, window_size_right = window_size
if k.dtype.is_floating_point:
k_dtype_bits = torch.finfo(k.dtype).bits
else:
k_dtype_bits = torch.iinfo(k.dtype).bits
is_mtp = (max_seqlen_q * seqused_k.size(0) == q.shape[0] and 1 < max_seqlen_q < 16)
if max_seqlen_q >= 16:
fp8_dtypes = [torch.float8_e4m3fn]
if hasattr(torch, "float8_e4m3fnuz"):
fp8_dtypes.append(torch.float8_e4m3fnuz)
if k_dtype_bits == 16 and q.shape[-1] == 256 and v.shape[-1] == 256:
out, softmax_lse = flash_attn_cuda.varlen_fwd_unified(
q,
k,
v,
out,
cu_seqlens_q,
max_seqlen_q,
seqused_k,
max_seqlen_k,
block_table,
softmax_scale,
softcap,
None, # q_descale
None, # k_descale
None, # v_descale
None, # output_scale
causal,
window_size_left,
window_size_right,
alibi_slopes,
use_alibi_sqrt,
qq_bias,
s_aux,
mm_prefix_range,
)
return (out, softmax_lse) if return_softmax_lse else out
bshd_prefill = _require_hg_varlen_symbol("hg_prefix_prefill_varlen_fwd")
fa_output, *rest_extend = bshd_prefill(
q,
k,
v,
out, # out_
cu_seqlens_q,
None, # cu_seqlens_k
seqused_k,
alibi_slopes,
block_table,
max_seqlen_q,
max_seqlen_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
causal,
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
1,
None if (k_dtype_bits == 16) else q_descale,
None if (k_dtype_bits == 16) else k_descale,
None if (k_dtype_bits == 16) else v_descale,
s_aux,
True,
)
return (fa_output, rest_extend[0]) if return_softmax_lse else fa_output
bs = seqused_k.size(0)
total_q = q.shape[0]
if max_seqlen_q == 1 or is_mtp:
if k_dtype_bits == 16 and s_aux is not None:
raise RuntimeError(
"b16 prefix decode with attention sink is not supported by unified attention yet"
)
assert not use_alibi_sqrt and qq_bias is None and mm_prefix_range is None, \
f"Arguments not supported in hg_bshd_decode"
bshd_pa_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")
result = bshd_pa_decode(
q,
k,
v,
out,
cu_seqlens_q,
None,
seqused_k,
alibi_slopes,
block_table,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale,
False,
causal,
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
1,
None if (k_dtype_bits == 16) else q_descale,
None if (k_dtype_bits == 16) else k_descale,
None if (k_dtype_bits == 16) else v_descale,
s_aux,
True,
)
fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output
out, softmax_lse = flash_attn_cuda.varlen_fwd_unified(
q,
k,
......@@ -3830,8 +3950,8 @@ def get_block_map_fast(q, k, topk_ratio, BLKQ=128, BLKK=64):
sparse_map = torch.zeros_like(pooled_score, dtype=torch.int8)
sparse_map.scatter_(-1, lut, 1)
return sparse_map, lut, topk
class SparseLinearAttention(nn.Module):
def __init__(self, head_dim, topk, feature_map='softmax', use_bf16=True, use_fp8=False, tie_feature_map_qk=True):
R'''
......@@ -3877,7 +3997,7 @@ class SparseLinearAttention(nn.Module):
with torch.no_grad():
nn.init.zeros_(self.proj_l.weight)
nn.init.zeros_(self.proj_l.bias)
def forward(self, q, k, v, return_sparsity=False):
R'''
Args:
......@@ -3886,18 +4006,18 @@ class SparseLinearAttention(nn.Module):
v: values of shape (B, L, H, D).
return_sparsity: whether to return the actual sparsity.
'''
B, seqlen_q, H, headdim = q.shape
if headdim == 64:
block_m = 64 if seqlen_q <= 2048 else 128
elif headdim == 128:
block_m = 64 if seqlen_q <= 2048 else 128
block_k = 64
block_k = 64
if headdim == 64:
sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
else:
sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
q = q.to(self.dtype)
k = k.to(self.dtype)
v = v.to(self.dtype)
......@@ -3981,7 +4101,7 @@ def sparse_attn_with_sla(
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
dtype = torch.bfloat16 if use_bf16 else torch.float16
......@@ -3994,12 +4114,12 @@ def sparse_attn_with_sla(
block_m = 64 if seqlen_q <= 2048 else 128
elif headdim == 128:
block_m = 64 if seqlen_q <= 2048 else 128
block_k = 64
block_k = 64
if headdim == 64:
sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=topk, BLKQ=block_m, BLKK=block_k)
else:
sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=topk, BLKQ=block_m, BLKK=block_k)
q = q.to(dtype)
k = k.to(dtype)
v = v.to(dtype)
......@@ -4045,15 +4165,6 @@ def _require_hg_varlen_symbol(name: str):
return symbol
def _apply_hg_kvcache_safe_env() -> None:
# DTK/gfx938 PA launch is sensitive to these knobs. Keep the old HG-safe
# defaults unless the caller explicitly asks for the raw kernel selection.
if os.environ.get("HG_KVCACHE_RAW_KERNEL") == "1":
return
os.environ.setdefault("PA_NO_MLS", "1")
os.environ.setdefault("PA_USE_TILE32X32", "1")
def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None:
if k_cache.dim() != 4 or v_cache.dim() != 4:
raise ValueError("HG paged KV cache expects k and v to both be 4D tensors")
......@@ -4066,53 +4177,56 @@ def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None:
"v=[num_blocks, page_block_size, num_heads_k, d_v]"
)
def _normalize_hg_paged_q_scales(q_scale, batch_size, num_heads_q, num_heads_k):
if q_scale is None:
raise ValueError("q_descale must be provided for HG int8 paged-kvcache path")
q_scale = maybe_contiguous(q_scale)
if q_scale.dim() == 1:
if q_scale.numel() == batch_size * num_heads_q:
q_scale = q_scale.view(batch_size, num_heads_q)
elif q_scale.numel() == batch_size * num_heads_k:
q_scale = q_scale.view(batch_size, num_heads_k)
if q_scale.dim() != 2 or q_scale.shape[0] != batch_size:
raise ValueError(
"q_descale must have shape [batch_size, num_heads_q] "
"or [batch_size, num_heads_k] for HG int8 paged-kvcache path"
)
if q_scale.shape[1] == num_heads_q:
return q_scale.contiguous()
if q_scale.shape[1] == num_heads_k and num_heads_q % num_heads_k == 0:
return q_scale.repeat_interleave(num_heads_q // num_heads_k, dim=1).contiguous()
raise ValueError(
"q_descale must have shape [batch_size, num_heads_q] "
"or [batch_size, num_heads_k] for HG int8 paged-kvcache path"
)
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
    is_fp8_kvcache: bool = False,
):
    """Return a ``(None, None)`` pair; every argument is ignored.

    NOTE(review): the pair presumably stands in for the
    ``(tile_scheduler_metadata, num_splits)`` values that
    ``flash_mla_with_kvcache`` accepts -- confirm against callers.
    """
    tile_scheduler_metadata = None
    num_splits = None
    return tile_scheduler_metadata, num_splits
def _expand_hg_paged_kv_scales(scale, block_table, page_block_size, num_heads_k, name):
if scale is None:
raise ValueError(f"{name} must be provided for HG int8 paged-kvcache path")
scale = maybe_contiguous(scale)
batch_size = block_table.shape[0]
if scale.dim() == 1 and scale.numel() == batch_size * num_heads_k:
scale = scale.view(batch_size, num_heads_k)
if scale.dim() != 2 or scale.shape != (batch_size, num_heads_k):
raise ValueError(
f"{name} must have shape [batch_size, num_heads_k] for HG int8 paged-kvcache path"
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: Optional[torch.Tensor],
num_splits: Optional[torch.Tensor],
softmax_scale: Optional[float] = None,
causal: bool = False,
use_cuda_graph: bool = True,
out: Optional[torch.Tensor] = None,
):
if k_cache.dtype not in (torch.float16, torch.bfloat16):
raise NotImplementedError(
"HG MLA dispatch in the main repository supports fp16/bf16 only; "
"fp8/int8 MLA is not supported."
)
expanded = torch.empty(
(int(block_table.max().item()) + 1, page_block_size, num_heads_k),
device=scale.device,
dtype=scale.dtype,
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
hg_mla = _require_hg_varlen_symbol("hg_fwd_kvcache_mla")
max_seqlen_k = 1 if use_cuda_graph else cache_seqlens.max().item()
result = hg_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
out,
max_seqlen_k,
)
for batch_idx in range(batch_size):
block_ids = block_table[batch_idx].to(dtype=torch.long)
expanded[block_ids] = scale[batch_idx].view(1, 1, num_heads_k).expand(
block_ids.numel(), page_block_size, num_heads_k
)
return expanded.contiguous()
if len(result) < 2:
raise RuntimeError("hg_fwd_kvcache_mla did not return softmax_lse")
return result[0], result[1]
def hg_flash_attn_varlen_func(
q,
......@@ -4247,8 +4361,6 @@ def hg_flash_attn_varlen_func(
unsupported.append("num_splits")
if fa_version != 2:
unsupported.append("fa_version")
if s_aux is not None:
unsupported.append("s_aux")
if custom_mask is not None:
unsupported.append("custom_mask")
if unsupported:
......@@ -4266,6 +4378,7 @@ def hg_flash_attn_varlen_func(
raise ValueError("cu_seqlens_q must be provided")
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
s_aux = maybe_contiguous(s_aux)
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -4333,10 +4446,6 @@ def hg_flash_attn_varlen_func(
raise ValueError("cu_seqlens_k and seqused_k cannot be provided at the same time")
if block_table is None:
raise ValueError("block_table must be provided when seqused_k is used")
if return_attn_probs:
raise NotImplementedError(
"return_attn_probs is not supported for HG prefix/paged compatibility paths"
)
if dropout_p != 0.0:
raise NotImplementedError("dropout_p must be 0.0 for HG prefix/paged compatibility paths")
......@@ -4346,6 +4455,33 @@ def hg_flash_attn_varlen_func(
k_dtype_bits = torch.iinfo(k.dtype).bits
if max_seqlen_q > 16 or (k_dtype_bits == 8 and max_seqlen_q > 1):
if k_dtype_bits == 16 and q.shape[-1] == 256 and v.shape[-1] == 256:
out, softmax_lse = flash_attn_cuda.varlen_fwd_unified(
q,
k,
v,
out,
cu_seqlens_q,
max_seqlen_q,
seqused_k,
max_seqlen_k,
block_table,
softmax_scale,
softcap,
None, # q_descale
None, # k_descale
None, # v_descale
None, # output_scale
causal,
window_size[0],
window_size[1],
alibi_slopes,
False, # use_alibi_sqrt
None, # qq_bias
s_aux,
None, # mm_prefix_range
)
return (out, softmax_lse) if wants_aux else out
prefix_prefill = _require_hg_varlen_symbol("hg_prefix_prefill_varlen_fwd")
result = prefix_prefill(
q,
......@@ -4366,15 +4502,16 @@ def hg_flash_attn_varlen_func(
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
return_softmax_lse or return_attn_probs,
1,
None if k_dtype_bits == 16 else q_descale,
None if k_dtype_bits == 16 else k_descale,
None if k_dtype_bits == 16 else v_descale,
s_aux,
is_bf16_output,
)
fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output
return (fa_output, result[1]) if (return_softmax_lse or return_attn_probs) else fa_output
if k_dtype_bits == 16:
prefix_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")
......@@ -4397,13 +4534,18 @@ def hg_flash_attn_varlen_func(
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
return_softmax_lse or return_attn_probs,
1,
None,
None,
None,
s_aux,
is_bf16_output,
)
fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output
return (fa_output, result[1]) if (return_softmax_lse or return_attn_probs) else fa_output
if return_softmax_lse:
if return_softmax_lse or return_attn_probs:
raise NotImplementedError(
"return_softmax_lse is not supported for the HG paged-kvcache compatibility path"
)
......@@ -4411,28 +4553,6 @@ def hg_flash_attn_varlen_func(
_validate_hg_paged_kv_contract(k, v)
if k.shape[1] != 128:
raise NotImplementedError("HG paged-kvcache path currently requires page_block_size == 128")
_apply_hg_kvcache_safe_env()
q_descale = _normalize_hg_paged_q_scales(
q_descale,
batch_size=block_table.shape[0],
num_heads_q=q.shape[1],
num_heads_k=k.shape[2],
)
k_descale = _expand_hg_paged_kv_scales(
k_descale,
block_table=block_table,
page_block_size=k.shape[1],
num_heads_k=k.shape[2],
name="k_descale",
)
v_descale = _expand_hg_paged_kv_scales(
v_descale,
block_table=block_table,
page_block_size=v.shape[1],
num_heads_k=v.shape[2],
name="v_descale",
)
hg_kvcache = _require_hg_varlen_symbol("hg_fwd_kvcache_bshd")
result = hg_kvcache(
q.unsqueeze(1),
......@@ -4442,7 +4562,7 @@ def hg_flash_attn_varlen_func(
None,
None,
seqused_k,
max_seqlen_k if max_seqlen_k > 0 else int(seqused_k.max().item()),
1,
None,
None,
None,
......@@ -4456,12 +4576,12 @@ def hg_flash_attn_varlen_func(
window_size[1],
softcap,
False,
num_splits,
-1,
None,
None,
q_descale,
k_descale,
v_descale,
None if k_dtype_bits == 16 else q_descale,
None if k_dtype_bits == 16 else k_descale,
None if k_dtype_bits == 16 else v_descale,
is_bf16_output,
)
return result[0].squeeze(1)
......@@ -52,6 +52,20 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
def cutlass_include_dirs():
    """Return the CUTLASS include directories that exist on this machine.

    Search order: the ``csrc/cutlass/include`` tree under ``this_dir``, then
    ``$CUTLASS_HOME/include`` when that variable is set, then two hard-coded
    fallback locations. Paths that do not exist are dropped; the survivors are
    returned as strings in search order.
    """
    search_roots = [Path(this_dir) / "csrc" / "cutlass" / "include"]

    env_root = os.getenv("CUTLASS_HOME")
    if env_root:
        search_roots.append(Path(env_root) / "include")

    search_roots += [
        Path("/workspace/cutlass_3.2.1/include"),
        Path("/public/home/huangly/数据采集/cutlass_3.2.1/include"),
    ]

    existing = []
    for root in search_roots:
        if root.exists():
            existing.append(str(root))
    return existing
def get_platform():
"""
Returns the platform name as used in wheel filenames.
......@@ -110,8 +124,18 @@ _HG_EXPLICIT_SOURCES_BY_MODE = {
"src/target/flash_fwd_hdim128_fp16.cpp",
"src/target/flash_fwd_hdim128_padding_mask_bf16.cpp",
"src/target/flash_fwd_hdim128_padding_mask_fp16.cpp",
"src/target/flash_fp8_fwd_hdim128_bf16.cpp",
"src/target/flash_fp8_fwd_hdim128_fp16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp",
"src/target/flash_fwd_hdim128_prefix_prefill_bf16.cpp",
"src/target/flash_fwd_hdim128_prefix_prefill_fp16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp",
"src/target/flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_bf16.cpp",
"src/target/flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_fp16.cpp",
"src/target/flash_fp8_fwd_hdim256_prefix_prefill_bf16.cpp",
"src/target/flash_fp8_fwd_hdim256_prefix_prefill_fp16.cpp",
"src/target/flash_fwd_hdim160_bf16.cpp",
"src/target/flash_fwd_hdim160_fp16.cpp",
"src/target/flash_fwd_hdim192_bf16.cpp",
......@@ -262,13 +286,64 @@ def _ninja_shell_join(args) -> str:
return " ".join(_ninja_escape(shlex.quote(str(x))) for x in args)
def _resolve_hg_compiler() -> str:
candidates = [
os.environ.get("FLASH_ATTN_HG_COMPILER"),
"/opt/dtk/bin/aicc",
"aicc",
]
for compiler in candidates:
if compiler and shutil.which(compiler):
return compiler
requested = [c for c in candidates if c]
raise RuntimeError(
"error: no usable HG aicc compiler found from: "
+ ", ".join(repr(c) for c in requested)
+ ". Set FLASH_ATTN_HG_COMPILER to the DTK aicc path."
)
def _normalize_hg_gfx_archs(gfx_version: str):
archs = []
for item in str(gfx_version).replace(",", ";").split(";"):
item = item.strip()
if not item:
continue
archs.append(item if item.startswith("gfx") else f"gfx{item}")
return archs
def _hg_target_define_value(archs):
return ",".join(arch[3:] if arch.startswith("gfx") else arch for arch in archs)
# Build the per-architecture device flags for the HG compile step.
# Each option is emitted as a pair: a "-Xarch_<arch>" selector followed by the
# "-mllvm=..." option it applies to. gfx936/gfx938 receive the 768-VGPR option,
# gfx928/gfx92a/gfx946 receive the 512-VGPR option; other names add nothing in
# those two branches. Returns a flat list of strings.
# NOTE(review): indentation is not visible in this view, so it is unclear whether
# the "-co-issue-vgpr-size=256" pair belongs only to the elif branch or is added
# for every arch in the loop — confirm against the real file before editing.
def _hg_arch_device_flags(archs):
flags = []
for arch in archs:
if arch in ("gfx936", "gfx938"):
flags.extend([f"-Xarch_{arch}", "-mllvm=-support-768-vgprs=true"])
elif arch in ("gfx928", "gfx92a", "gfx946"):
flags.extend([f"-Xarch_{arch}", "-mllvm=-support-512-vgprs=true"])
flags.extend([f"-Xarch_{arch}", "-mllvm=-co-issue-vgpr-size=256"])
return flags
def _is_rocm_5_7() -> bool:
version_file = "/opt/rocm/.info/version"
try:
with open(version_file, "r", encoding="utf-8") as f:
return f.read(100).startswith("5.7.0")
except OSError:
return False
def compute_hg_build_descriptor(
src_dir,
build_dir,
mode="all",
extra_options_raw="-DGFX_VERSION=938 -Wl,-Bsymbolic",
extra_options_raw="-DGFX_VERSION=938;936 -Wl,-Bsymbolic",
):
"""Collect HG sources and flags for Ninja (no compile). Default: mode=all, gfx938."""
"""Collect HG sources and flags for Ninja (no compile). Default: mode=all, gfx938/gfx936."""
import sysconfig as _sysconfig
src_dir = os.path.abspath(str(src_dir))
......@@ -279,7 +354,7 @@ def compute_hg_build_descriptor(
BUILD_FA_FWD = BUILD_FA_BWD = BUILD_FA_KVCACHE = False
BUILD_FA_PERMUTE = BUILD_FLASHMLA = False
BUILD_C_INTERFACE = False
BUILD_ASM = False
BUILD_ASM = True
FA_DEBUG = True
FA_DEBUG_SUM_MAX = False
HEADDIM_128_ONLY = False
......@@ -355,14 +430,11 @@ def compute_hg_build_descriptor(
EXTRA_HIP_FLAGS.append(_tok)
if GFX_VERSION is None:
GFX_VERSION = "938"
GFX_VERSION = "938;936"
ROCM_PATH = os.environ.get("ROCM_PATH", os.environ.get("ROCM_HOME", "/opt/rocm"))
HG_COMPILER = _resolve_hg_compiler()
if not shutil.which("hipcc"):
raise RuntimeError(
"error: hipcc not found in PATH. Please activate the DTK environment first."
)
if not os.path.isdir(os.path.join(ROCM_PATH, "include")):
raise RuntimeError(
f"error: {ROCM_PATH}/include not found. "
......@@ -400,7 +472,8 @@ def compute_hg_build_descriptor(
"-lc10",
]
_gfx_comma = GFX_VERSION.replace(";", ",")
HG_ARCHS = _normalize_hg_gfx_archs(GFX_VERSION)
_gfx_comma = _hg_target_define_value(HG_ARCHS)
DEFINES = [
f"-DTARGET={_gfx_comma}",
"-D__HIP_PLATFORM_AMD__=1",
......@@ -444,8 +517,10 @@ def compute_hg_build_descriptor(
DEFINES.append("-DPA_PAGE_BLOCK_SIZE")
if MLA_PAGE_BLOCK_SIZE:
DEFINES.append("-DMLA_PAGE_BLOCK_SIZE")
if _is_rocm_5_7():
DEFINES.append("-DROCM_5_7")
OFFLOAD_FLAGS = [f"--offload-arch=gfx{_g}" for _g in GFX_VERSION.split(";") if _g]
OFFLOAD_FLAGS = [f"--offload-arch={_g}" for _g in HG_ARCHS]
INCLUDE_FLAGS = [
f"-I{ROCM_PATH}/include",
......@@ -457,23 +532,36 @@ def compute_hg_build_descriptor(
INCLUDE_FLAGS += TORCH_INCLUDE_FLAGS
COMMON_FLAGS = [
"-fPIC",
"-O3",
"-std=c++17",
"-fPIC",
"-ffast-math",
"-fno-finite-math-only",
"-fno-gpu-rdc",
"-mno-fma",
]
DTK_DEVICE_FLAGS = [
"-DHIP_ENABLE_WARP_SYNC_BUILTINS",
"-mllvm",
"-slp-phi-tree-bb-max-size=10000",
"-disable-machine-sink",
"-mllvm",
"-enable-num-vgprs-512=true",
"-Rpass-analysis=kernel-resource-usage",
"-ftemplate-backtrace-limit=0",
"-disable-code-sink",
"-mcode-object-version=5",
]
if not _is_rocm_5_7():
DTK_DEVICE_FLAGS += [
"-mllvm",
"-amdgpu-enable-rewrite-partial-reg-uses=false",
"-mllvm",
"-allow-gvn-convergent-call=true",
"-mllvm",
"-disallow-uniform-vmed3-combine=true",
"-mllvm",
"-hcu-pre-emit-load-store-opt=false",
"-mllvm",
"-amdgpu-early-inline-all=true",
"-mllvm",
"-amdgpu-function-calls=false",
]
DTK_DEVICE_FLAGS += _hg_arch_device_flags(HG_ARCHS)
if os.environ.get("FLASH_ATTN_HG_SAVE_TEMPS", "") == "1":
DTK_DEVICE_FLAGS.append("--save-temps")
......@@ -555,6 +643,7 @@ def compute_hg_build_descriptor(
"obj_dir": obj_dir,
"sources": _all_sources,
"objects": objects,
"compiler": HG_COMPILER,
"compile_flags": compile_flags,
"link_flags": link_flags,
"out_so": out_so,
......@@ -566,16 +655,17 @@ def run_hg_ninja_build(descriptor: dict) -> None:
"""Write build_hg.ninja and run ninja (parallel via MAX_JOBS)."""
build_dir = descriptor["build_dir"]
ninja_file = descriptor["ninja_path"]
compiler = _ninja_shell_join([descriptor["compiler"]])
out_so_ninja = _ninja_escape_path(descriptor["out_so"])
lines = [
"ninja_required_version = 1.3",
"",
"rule hipcc_compile",
" command = hipcc -c $in -o $out $FLAGS",
"rule hg_compile",
f" command = {compiler} -c $in -o $out $FLAGS",
" description = HG compile $in",
"",
"rule hipcc_link",
" command = hipcc -shared -o $out @$out.rsp $LINK_FLAGS",
"rule hg_link",
f" command = {compiler} -shared -o $out @$out.rsp $LINK_FLAGS",
" rspfile = $out.rsp",
" rspfile_content = $in",
" description = HG link $out",
......@@ -585,9 +675,9 @@ def run_hg_ninja_build(descriptor: dict) -> None:
"",
]
for src, obj in zip(descriptor["sources"], descriptor["objects"]):
lines.append(f"build {_ninja_escape_path(obj)}: hipcc_compile {_ninja_escape_path(src)}")
lines.append(f"build {_ninja_escape_path(obj)}: hg_compile {_ninja_escape_path(src)}")
obj_list = " ".join(_ninja_escape_path(obj) for obj in descriptor["objects"])
lines.append(f"build {out_so_ninja}: hipcc_link {obj_list}")
lines.append(f"build {out_so_ninja}: hg_link {obj_list}")
lines.append("")
os.makedirs(build_dir, exist_ok=True)
......@@ -616,6 +706,7 @@ HG_BUILD_DIR = os.path.join(this_dir, "build", "flash_attn_hg")
HG_SO_BUILD = os.path.join(HG_BUILD_DIR, "libflash_attention.so")
HG_SO_PKG = os.path.join(this_dir, "flash_attn", "lib", "libflash_attention.so")
HG_LIB_DIR = os.path.dirname(HG_SO_PKG)
os.environ['PYTORCH_NVCC'] = 'aicc'
# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
......@@ -663,7 +754,21 @@ if not SKIP_CUDA_BUILD:
# HAS_HG_DISPATCH / -lflash_attention are applied there if the .so exists.
hg_compile_defs = []
hg_link_args = []
aicc_flags = [
"-mcode-object-version=5",
"-mllvm=-support-768-vgprs=true",
"-mllvm=-disable-machine-sink",
"-mllvm=-disable-code-sink",
"-mllvm=-amdgpu-enable-rewrite-partial-reg-uses=false",
"-mllvm=-allow-gvn-convergent-call=true",
"-mllvm=-disallow-uniform-vmed3-combine=true",
"-mllvm=-hcu-pre-emit-load-store-opt=false",
"-mllvm=-amdgpu-early-inline-all=true",
"-mllvm=-amdgpu-function-calls=false",
"-fno-finite-math-only",
"--gpu-max-threads-per-block=256",
"-mllvm=-unroll-threshold=10000"
]
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
......@@ -896,7 +1001,7 @@ if not SKIP_CUDA_BUILD:
"-std=c++17",
"-DDCU_ASM",
# "-mllvm -not-combine-fma=true",
"-mllvm -slp-phi-tree-bb-max-size=10000",
# "-mllvm -slp-phi-tree-bb-max-size=10000",
# "-mllvm -allow-cse-cross-bb-convergent-call=true",
# "-mllvm -full-vectorize-slp=true",
f"-DFLASH_ATTENTION_BF16_TYPE={bf16_type}",
......@@ -936,6 +1041,7 @@ if not SKIP_CUDA_BUILD:
]
+ generator_flag
+ hg_compile_defs
+ aicc_flags
# + cc_flag
),
},
......@@ -944,6 +1050,7 @@ if not SKIP_CUDA_BUILD:
Path(this_dir) / "csrc" / "flash_attn",
Path(this_dir) / "csrc" / "flash_attn" / "src",
Path(this_dir) / "csrc" / "cutlass" / "include",
],
)
)
......@@ -1051,13 +1158,16 @@ class NinjaBuildExtension(BuildExtension):
if os.path.isdir(HG_SRC_DIR):
os.makedirs(HG_BUILD_DIR, exist_ok=True)
_maybe_clean_hg_build_dir(HG_BUILD_DIR)
print("=== Building HG libflash_attention.so (mode=all, gfx938, ninja) ===")
try:
desc = compute_hg_build_descriptor(
HG_SRC_DIR,
HG_BUILD_DIR,
mode="all",
extra_options_raw="-DGFX_VERSION=938 -Wl,-Bsymbolic",
extra_options_raw="-DGFX_VERSION=938;936 -Wl,-Bsymbolic",
)
print(
"=== Building HG libflash_attention.so "
f"(mode=all, gfx938/gfx936, ninja, compiler={desc['compiler']}) ==="
)
run_hg_ninja_build(desc)
if os.path.isfile(HG_SO_BUILD):
......@@ -1066,11 +1176,11 @@ class NinjaBuildExtension(BuildExtension):
use_hg = True
print(f"=== Copied HG .so -> {HG_SO_PKG} ===")
else:
print("WARNING: HG build completed but output .so is missing; continuing without HG dispatch")
raise RuntimeError("Error: HG build completed but output .so is missing")
except Exception as e:
print(f"WARNING: HG build failed ({e}), continuing without HG dispatch")
raise RuntimeError(f"Error: HG build failed ({e})")
else:
print(f"WARNING: HG source directory not found ({HG_SRC_DIR}), continuing without HG dispatch")
raise RuntimeError(f"Error: HG source directory not found ({HG_SRC_DIR})")
else:
# FLASH_BUILD_HG=0 should deterministically disable dispatch even if stale artifacts exist.
if os.path.isfile(HG_SO_PKG):
......
import os
import sys
import math
import torch
import pickle
import time
import numpy
import argparse
import random
from datetime import datetime
use_cuda_toolkits = os.path.exists("/usr/local/cuda/bin/nvcc")
use_rocm_toolkits = os.path.exists("/opt/rocm/llvm/bin/clang")
use_dtk_toolkits = os.path.exists("/opt/dtk/bin/aicc")
if (use_cuda_toolkits):
from vllm.vllm_flash_attn import flash_attn_varlen_func
elif (use_rocm_toolkits or use_dtk_toolkits):
try:
from flash_attention_interface import flash_attn_varlen_func, flash_attn_2_cuda, flash_attn_with_kvcache
except ModuleNotFoundError:
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
import flash_attn_2_cuda as flash_attn_cuda
def _require_hg_varlen_symbol(name: str):
symbol = getattr(flash_attn_cuda, name, None)
if symbol is None:
raise RuntimeError(
f"{name} is unavailable in this build. Rebuild flash_attn with HAS_HG_DISPATCH enabled."
)
return symbol
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, do_assert=True, cos_threshold=1e-5) -> None:
    """Print error metrics between two same-shaped tensors and optionally assert.

    Both tensors are promoted to float64 before comparing. The printed
    metrics are the cosine-style difference, RMSE, max absolute difference,
    and the mean / max of ``|x / y|``. When ``do_assert`` is true the
    cosine-style difference must be strictly below ``cos_threshold``.
    """
    assert x.shape == y.shape, "for {}, x and y must have the same shape".format(name)
    lhs = x.double()
    rhs = y.double()
    delta = lhs - rhs
    ratio = (lhs / rhs).abs()

    RMSE = (delta * delta).mean().sqrt().item()
    energy = (lhs * lhs + rhs * rhs).sum().item()
    # The 1e-12 floor keeps the division finite when both tensors are all zeros.
    cos_diff = 1 - 2 * (lhs * rhs).sum().item() / max(energy, 1e-12)
    amax_diff = delta.abs().max().item()
    rel_diff_mean = ratio.mean().item()
    rel_diff_max = ratio.max().item()

    report = "name:{} cos_diff={:.12f}, RMSE=\x1b[35m{:.12f}\x1b[0m, amax_diff=\x1b[35m{:.12f}\x1b[0m, REL=\x1b[35m{:.12f}\x1b[0m, rel_max=\x1b[35m{:.12f}\x1b[0m"
    print(report.format(name, cos_diff, RMSE, amax_diff, rel_diff_mean, rel_diff_max))
    if do_assert:
        assert cos_diff < cos_threshold
# Reference ("golden") attention for a single sequence, built from plain PyTorch ops in float32.
# Dims 0 and 1 of each input are swapped before the matmul — presumably the callers pass
# (seq, heads, dim) tensors so the math runs on (heads, seq, dim); TODO confirm against callers.
# Returns (output, lse, scores_max, scores_sum): output is cast back to the input dtype and moved
# back to the input device; the other three are moved back to the input device.
# original_seqlen_kv, split_slice and is_bshd are not referenced anywhere in this body.
def scaled_dot_product_attention(__query, __key, __value, h_q, h_kv, is_causal=False, USE_CPU=False, return_max_sum=False, original_seqlen_kv=0, split_slice=0, is_bshd=False, window_size=(-1, -1)):
__query = __query.transpose(0, 1).contiguous()
__key = __key.transpose(0, 1).contiguous()
__value = __value.transpose(0, 1).contiguous()
# Optionally compute the golden result on the CPU, to avoid the influence of the BLAS backend
original_device = __query.device
original_dtype = __query.dtype
if (USE_CPU):
__query = __query.cpu()
__key = __key.cpu()
__value = __value.cpu()
# print("scaled_dot_product_attention: ", query.shape, key.shape, value.shape)
__query = __query.float()
__key = __key.float()
__value = __value.float()
# When returning results the official way: repeat each KV head h_q // h_kv times along dim 0
# so K/V line up with the query heads.
# NOTE(review): indentation is not visible here, so the exact extent of this `if` body is unconfirmed.
if (not return_max_sum):
__key = __key.repeat_interleave(h_q // h_kv, dim=0)
__value = __value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = __query @ __key.transpose(-2, -1) / math.sqrt(__query.size(-1))
# MTP > 1, causal/local mask applied
if (window_size != (-1, -1)):
s_q = __query.shape[-2]
s_k = __key.shape[-2]
left, right = window_size
# A negative window bound is replaced by s_k, i.e. unbounded on that side.
if left < 0:
left = s_k
if right < 0:
right = s_k
row_idx = torch.arange(s_q, dtype=torch.int32, device=attn_weight.device)[:, None]
col_idx = torch.arange(s_k, dtype=torch.int32, device=attn_weight.device)[None, :]
# Query rows are aligned to the end of the key sequence (offset s_k - s_q).
col_idx_limit_left = row_idx + s_k - s_q - left
col_idx_limit_right = row_idx + s_k - s_q + right
temp_mask = (col_idx >= col_idx_limit_left) & (col_idx <= col_idx_limit_right)
attn_weight = attn_weight.masked_fill(temp_mask.logical_not(), float("-inf"))
elif (is_causal):
s_q = __query.shape[-2]
s_k = __key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=__query.dtype, device=attn_weight.device)
# Lower-triangular mask shifted by s_k - s_q so the last query row sees every key.
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=attn_weight.device).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
# NOTE(review): the result of .to() is discarded, so this line has no effect
# (attn_bias was already created with dtype=__query.dtype).
attn_bias.to(__query.dtype)
attn_weight += attn_bias
# some codes for debug
scores_max = attn_weight.to(torch.float32).max(-1)[0]
scores_sum = torch.exp(attn_weight.to(torch.float32) - scores_max.unsqueeze(-1)).sum(dim=-1)
# original codes
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
output = attn_weight @ __value
output = output.transpose(0, 1).contiguous()
return output.to(original_device).to(original_dtype), lse.to(original_device), scores_max.to(original_device), scores_sum.to(original_device)
# Seed the Python, NumPy and PyTorch (CPU and, when available, CUDA) random generators with
# `seed`, request deterministic cuDNN behavior, and limit OpenMP / PyTorch to one thread.
# NOTE(review): indentation is not visible here, so it is unconfirmed which of the lines after
# the `if torch.cuda.is_available()` check are inside that branch.
def set_random_seed(seed=0):
random.seed(seed) # Seed Python's built-in random module
numpy.random.seed(seed) # Seed NumPy's random generator
torch.manual_seed(seed) # Seed PyTorch's random generator
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) # Seed the random generator of every GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['OMP_NUM_THREADS'] = '1' # Set the OpenMP thread count
torch.set_num_threads(1) # Set PyTorch's thread count
if __name__ == '__main__':
    # Stand-alone accuracy / performance / stability check for the HG paged
    # prefix-decode entry point (`hg_prefix_decode_varlen_fwd`).
    #   1. Build (or load) paged-KV inputs.
    #   2. Rebuild dense K/V per sequence and compute a float32 reference.
    #   3. Run the kernel once and compare output and LSE with the reference.
    #   4. Optionally benchmark and run a bit-exact repeatability ("pressure") test.
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('--load', default=False, action='store_true', help='load path')
    parser.add_argument('--trace', default=False, action='store_true', help='whether dump perf traces')
    parser.add_argument('--bf16', default=False, action='store_true', help='whether use bfloat16 as main dtype')
    parser.add_argument('--fp8', default=False, action='store_true', help='whether use fp8_e4m3 inputs for HG decode')
    parser.add_argument('--pressure', default=False, action='store_true', help='whether do pressure test')
    parser.add_argument('--cpu', default=False, action='store_true', help='whether compute golden via cpu')
    parser.add_argument('--pad', default=False, action='store_true', help='whether make query uncontiguous to simulate vllm behaviors')
    parser.add_argument('--iterations', type=int, default=100, help='pressure test times')
    parser.add_argument('--block_size', type=int, default=128, help='page block_size')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size for generated inputs')
    parser.add_argument('--seq-q', type=int, default=4, help='query length per batch for generated inputs')
    parser.add_argument('--seq-k', type=int, default=2048, help='kv length per batch for generated inputs')
    parser.add_argument('--num-heads', type=int, default=24, help='number of query heads for generated inputs')
    parser.add_argument('--num-heads-kv', type=int, default=2, help='number of kv heads for generated inputs')
    parser.add_argument('--head-dim-qk', type=int, default=128, help='query/key head dimension')
    parser.add_argument('--head-dim-v', type=int, default=128, help='value head dimension')
    parser.add_argument('--no-causal', dest='causal', default=True, action='store_false', help='disable causal mask for generated inputs')
    parser.add_argument('--window-left', type=int, default=-1, help='left sliding window size')
    parser.add_argument('--window-right', type=int, default=-1, help='right sliding window size')
    # Fixed: the help text used to be a copy of the --pressure description.
    parser.add_argument('--seed', default=False, action='store_true', help='whether fix the random seed for reproducible inputs')
    args = parser.parse_args()

    if (args.seed):
        set_random_seed(212)

    # Load the inputs from a file
    if (args.load):
        # NOTE(review): torch.load unpickles arbitrary objects; only load dumps from a trusted source.
        nvidia_packet = torch.load("./demo.pt")
        query, key, value, cu_seqlens_q, max_seqlen_q, cache_seqlens, max_seqlen_k, softmax_scale, causal, window_size, alibi_slopes, page_table, softcap, fa_version, q_descale, k_descale, v_descale = nvidia_packet["inputs"]
        vllm_golden = nvidia_packet["outputs"]
        # Fixed: the rest of the script reads q / k_cache / v_cache, which the
        # load path never defined (NameError on the first use).
        q, k_cache, v_cache = query, key, value
        # Derive the remaining parameters from the loaded tensors
        batch_size = page_table.shape[0]
        assert batch_size == cu_seqlens_q.shape[0] - 1, "check batch size"
        page_block_size = key.shape[1]
        num_heads_kv = key.shape[2]
        num_heads = query.shape[1]
        head_dim_qk = query.shape[2]
        # Fixed: the value head dim must come from the value cache, not the key
        # cache; the two differ for e.g. qk=192 / v=128.
        head_dim_v = value.shape[3]
        infer_dtype = query.dtype
    else:
        # Generate the inputs from the command-line sizes
        batch_size = args.batch_size

        # Query lengths (identical for every batch entry)
        seqlen_q = [args.seq_q for i in range(batch_size)]
        seqlen_q_sum = sum(seqlen_q)
        max_seqlen_q = max(seqlen_q)
        cu_seqlens_q = numpy.array([0] + numpy.cumsum(seqlen_q).tolist()).astype("int32")
        cu_seqlens_q = torch.from_numpy(cu_seqlens_q)

        # KV lengths (identical for every batch entry)
        cache_seqlens = [args.seq_k for i in range(batch_size)]

        # Page block size: 16 when the CUDA toolkit is detected, otherwise --block_size
        page_block_size = 16 if (use_cuda_toolkits) else args.block_size

        # Number of pages each sequence actually needs for its KV length
        max_seqlen_k = max(cache_seqlens)
        seqlen_kv_real_required_page = [math.ceil(it / page_block_size) for it in cache_seqlens]

        # The cache itself is sized for the longest sequence in every batch entry
        seqlen_kv_max_required_page = math.ceil(max_seqlen_k / page_block_size)
        seqlen_kv_max_required_page_total = batch_size * seqlen_kv_max_required_page

        # Shuffle the physical page order so page-table indirection is exercised
        shuffle = True
        if (shuffle):
            block_random = torch.randperm(seqlen_kv_max_required_page_total, dtype=torch.int32, device="cuda")
        else:
            block_random = torch.arange(seqlen_kv_max_required_page_total , dtype=torch.int32)

        page_table = []
        seq_block_incre = 0
        for i in range(batch_size):
            blocks_pad = [0] * seqlen_kv_max_required_page
            if (shuffle):
                # Only the pages a sequence really uses get an id; the tail stays 0.
                blocks_pad[:seqlen_kv_real_required_page[i]] = block_random[seq_block_incre: seq_block_incre + seqlen_kv_real_required_page[i]].cpu().tolist()
                seq_block_incre += seqlen_kv_real_required_page[i]
            else:
                blocks_pad = block_random[seq_block_incre: seq_block_incre + seqlen_kv_max_required_page].cpu().tolist()
                seq_block_incre += seqlen_kv_max_required_page
            page_table.append(torch.IntTensor(blocks_pad))
        page_table = torch.stack(page_table).contiguous().to("cuda")

        # Basic parameters
        head_dim_qk = args.head_dim_qk
        head_dim_v = args.head_dim_v
        num_heads = args.num_heads
        num_heads_kv = args.num_heads_kv
        infer_dtype = torch.float16  # float16 by default
        if (args.bf16): infer_dtype = torch.bfloat16  # --bf16 switches to bfloat16
        softmax_scale = 1.0 / math.sqrt(head_dim_qk)
        causal = args.causal
        window_size = (args.window_left, args.window_right)
        alibi_slopes = None
        softcap = 0.0
        fa_version = 2
        q_descale = torch.ones((batch_size, num_heads), dtype=torch.float32, device="cuda")
        k_descale = torch.ones((batch_size, num_heads_kv), dtype=torch.float32, device="cuda")
        v_descale = torch.ones((batch_size, num_heads_kv), dtype=torch.float32, device="cuda")

        # Input tensors
        if (args.pad):
            # Slice the heads out of a wider buffer so q is non-contiguous
            query_origin_tensor = torch.randn((seqlen_q_sum, num_heads + 16, head_dim_qk), dtype=infer_dtype, device="cuda")
            q = query_origin_tensor[:, :num_heads]
        else:
            q = torch.randn((seqlen_q_sum, num_heads, head_dim_qk), dtype=infer_dtype, device="cuda")
        k_cache = torch.randn((seqlen_kv_max_required_page_total, page_block_size, num_heads_kv, head_dim_qk), device="cuda", dtype=infer_dtype)
        v_cache = torch.randn((seqlen_kv_max_required_page_total, page_block_size, num_heads_kv, head_dim_v), device="cuda", dtype=infer_dtype)
        vllm_golden = None
        cu_seqlens_q = cu_seqlens_q.to(q.device)
        cache_seqlens = torch.from_numpy(numpy.array(cache_seqlens).astype("int32")).to(q.device)

    # The reference path uses the *_ref tensors; with --fp8 they are the
    # fp8-rounded values cast back to infer_dtype, so kernel and reference see
    # the same quantized data.
    q_ref = q
    k_cache_ref = k_cache
    v_cache_ref = v_cache
    if args.fp8:
        if not hasattr(torch, "float8_e4m3fn"):
            raise RuntimeError("This PyTorch build does not support torch.float8_e4m3fn")
        q = q.to(torch.float8_e4m3fn)
        k_cache = k_cache.to(torch.float8_e4m3fn)
        v_cache = v_cache.to(torch.float8_e4m3fn)
        q_ref = q.to(infer_dtype)
        k_cache_ref = k_cache.to(infer_dtype)
        v_cache_ref = v_cache.to(infer_dtype)

    # Show the input data
    print("--------------------------------------------------------------------------------------------")
    print("q: ", q.shape, q.dtype, q.is_contiguous(), q.stride())
    print("k_cache: ", k_cache.shape, k_cache.dtype, k_cache.is_contiguous(), k_cache.stride())
    print("v_cache: ", v_cache.shape, v_cache.dtype, v_cache.is_contiguous(), v_cache.stride())
    print("cu_seqlens_q: ", cu_seqlens_q.shape, cu_seqlens_q.dtype, cu_seqlens_q.is_contiguous())
    print("cu_seqlens_q: ", cu_seqlens_q)
    print("max_seqlen_q: ", max_seqlen_q)
    print("cache_seqlens: ", cache_seqlens)
    print("max_seqlen_k: ", max_seqlen_k)
    print("softmax_scale: ", softmax_scale)
    print("causal: ", causal)
    print("window_size: ", window_size)
    print("alibi_slopes: ", alibi_slopes)
    print("page_table: ", page_table.shape, page_table.dtype, page_table.is_contiguous(), page_table.stride())
    print("page_table: ", page_table)
    print("softcap: ", softcap)
    print("fa_version: ", fa_version)
    print("q_descale: ", q_descale.shape, q_descale.dtype, q_descale.tolist())
    print("k_descale: ", k_descale.shape, k_descale.dtype, k_descale.tolist())
    print("v_descale: ", v_descale.shape, v_descale.dtype, v_descale.tolist())
    print("--------------------------------------------------------------------------------------------")

    # Rebuild the dense key / value of every sequence from the paged cache
    key_original = []
    value_original = []
    for b in range(batch_size):
        # Page-table row of this sequence
        index = page_table[b]
        # Keep only the pages this sequence really uses
        max_page_blocks = math.ceil(cache_seqlens[b] / page_block_size)
        actual_index = index[:max_page_blocks]
        # Gather those pages, flatten them to (tokens, heads, dim) and trim to the KV length
        key_content = k_cache_ref[actual_index]
        key_content = key_content.view(-1, num_heads_kv, head_dim_qk)[:cache_seqlens[b]].contiguous()
        # Same for the values
        value_content = v_cache_ref[actual_index].view(-1, num_heads_kv, head_dim_v)[:cache_seqlens[b]].contiguous()
        key_original.append(key_content)
        value_original.append(value_content)

    # Split the packed query back into per-sequence pieces
    query_original = []
    cum_q = 0
    for b in range(batch_size):
        query_len = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
        query_content = q_ref[cum_q: cum_q + query_len]
        query_original.append(query_content.contiguous())
        cum_q += query_len

    # Reference self-attention, one sequence at a time
    golden = []
    golden_lse = []
    golden_max = []
    for b in range(batch_size):
        tmp_output, lse, scores_max, scores_sum = scaled_dot_product_attention(query_original[b], key_original[b], value_original[b], num_heads, num_heads_kv, is_causal=causal, USE_CPU=args.cpu, window_size=window_size)
        golden.append(tmp_output)
        golden_lse.append(lse)
        golden_max.append(scores_max)
    golden = torch.cat(golden, dim=0)
    golden_lse = torch.cat(golden_lse, dim=-1)
    golden_max = torch.cat(golden_max, dim=-1)
    print("golden: ", golden.shape)
    print("golden_lse: ", golden_lse.shape)
    print("--------------------------------------------------------------------------------------------")

    # fa_output, fa_lse = flash_attn_2_cuda.prefix_decode_varlen_fwd(
    bshd_pa_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")

    def run_hg_prefix_decode(q_in, k_in, v_in):
        # Single place that spells out the positional argument list, so the
        # accuracy run, the benchmark and the pressure test cannot drift apart
        # (the pressure test previously dropped the s_aux slot).
        return bshd_pa_decode(
            q_in,
            k_in,
            v_in,
            None,  # out_
            cu_seqlens_q,
            None,  # cu_seqlens_k
            cache_seqlens,
            alibi_slopes,
            page_table,
            max_seqlen_q,
            max_seqlen_k,
            0.0,  # dropout
            softmax_scale,
            False,  # zero_tensors
            causal,
            window_size[0],
            window_size[1],
            softcap,
            True,  # return_softmax_lse,
            1,
            q_descale if args.fp8 else None,
            k_descale if args.fp8 else None,
            v_descale if args.fp8 else None,
            None,  # s_aux
            infer_dtype == torch.bfloat16,
        )

    if (True):
        fa_output, fa_lse = run_hg_prefix_decode(q, k_cache, v_cache)
    else:
        # Alternative path through the public Python wrapper (disabled).
        fa_output, fa_lse, *rest = flash_attn_with_kvcache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            page_table=page_table,
            cache_seqlens=cache_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k_new=cache_seqlens,
            max_seqlen_q=max_seqlen_q,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            k_descale=k_descale,
            v_descale=v_descale,
            return_softmax_lse=True,
        )
    torch.cuda.synchronize()

    if (vllm_golden is not None):
        # Check the kernel against the output stored in the loaded dump
        cal_diff(fa_output, vllm_golden, "check")
    print("fa_output: ", fa_output.shape)
    if (fa_lse is not None): print("fa_lse: ", fa_lse.shape)

    # Accuracy against the float32 reference
    fp8_threshold = 5e-3
    cal_diff(golden, fa_output, "accuracy", True, fp8_threshold if args.fp8 else 1e-5)
    if (fa_lse is not None): cal_diff(golden_lse, fa_lse, "softmax_lse", True, fp8_threshold if args.fp8 else 1e-5)
    print("--------------------------------------------------------------------------------------------")

    # Performance benchmark
    def benchmark_prefix_prefill():
        _ = run_hg_prefix_decode(q, k_cache, v_cache)

    # Skipped while debugging (FA_DEBUG / HIP_LOG_LEVEL set) or tracing
    if ((os.getenv("FA_DEBUG") is None) and (os.getenv("HIP_LOG_LEVEL") is None) and not args.trace):
        # Imported here so triton is only required when the benchmark runs.
        import triton
        t = triton.testing.do_bench_cudagraph(benchmark_prefix_prefill)
        FLOPS = float(0)
        BYTES = float(0)
        for b in range(batch_size):
            batch_seqlen_q = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
            batch_seqlen_k = cache_seqlens[b]
            effective_seqlen_k = batch_seqlen_k
            if window_size != (-1, -1):
                window_left, window_right = window_size
                left = batch_seqlen_k if window_left < 0 else window_left
                right = batch_seqlen_k if window_right < 0 else window_right
                effective_seqlen_k = min(batch_seqlen_k, left + batch_seqlen_q + right)
            # Pure causal masking removes roughly half of the q x q corner
            undo_flops = batch_seqlen_q * batch_seqlen_q / 2 if (causal and window_size == (-1, -1)) else 0
            attn_elems = batch_seqlen_q * effective_seqlen_k - undo_flops
            qk_flops = num_heads * attn_elems * head_dim_qk * 2
            pv_flops = num_heads * attn_elems * head_dim_v * 2
            FLOPS += qk_flops + pv_flops
            q_load = batch_seqlen_q * num_heads * head_dim_qk
            k_load = effective_seqlen_k * num_heads_kv * head_dim_qk # k load not only once
            v_load = effective_seqlen_k * num_heads_kv * head_dim_v
            BYTES += q_load * q.element_size() + k_load * k_cache.element_size() + v_load * v_cache.element_size() # ignore storation ?
        print(f"Performance: {t:.3f} ms, \x1b[35m{FLOPS / 10 ** 9 / t:.2f}\x1b[0m TFLOPS, \x1b[35m{BYTES / 10 ** 6 / t:.0f}\x1b[0m GB/s")

    # Pressure test: repeated runs on cloned inputs must be bit-identical
    if (args.pressure):
        pressure_count = max(100, args.iterations)
        for p in range(pressure_count):
            pressure_fa_output, _ = run_hg_prefix_decode(q.clone(), k_cache.clone(), v_cache.clone())
            torch.cuda.synchronize()
            is_equal = torch.equal(pressure_fa_output, fa_output)
            if (not is_equal): cal_diff(pressure_fa_output, fa_output, "pressure")
            assert is_equal, "\x1b[31mUnstable\x1b[0m!"
            del pressure_fa_output
            sys.stdout.write("\rPressure Test: {}/{}".format(p + 1, pressure_count))
        print(" \x1b[32mPASS\x1b[0m")
    print("-----------------------------------------------------------------------------------")
import torch
import os
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -6,6 +7,39 @@ from vllm.triton_utils import tl, triton
import math
import time
from typing import Optional
UNIFIED_BLOCK_SIZE = int(os.getenv("UNIFIED_BLOCK_SIZE", "128"))
def estimate_unified_attention_bytes(
    batch_size,
    seqlen_q,
    seqlen_k,
    nheads,
    nheads_k,
    d,
    block_size,
    q_bytes,
    k_bytes,
    v_bytes,
    out_bytes=0,
    d_v=None,
    window_size=(-1, -1),
):
    """Estimate the bytes moved by one unified-attention call.

    The estimate is the sum of the query, key/value, output and metadata
    traffic. ``d_v`` falls back to ``d`` when omitted. When ``window_size``
    is not ``(-1, -1)`` the KV length is capped at ``left + seqlen_q + right``,
    with a negative bound standing for ``seqlen_k``. Metadata is counted as
    4 bytes per entry for ``batch_size + 1`` offsets, ``batch_size`` lengths
    and ``batch_size * ceil(kv_len / block_size)`` block-table entries.
    """
    value_dim = d_v if d_v is not None else d

    if window_size == (-1, -1):
        kv_len = seqlen_k
    else:
        win_l, win_r = window_size
        span_l = win_l if win_l >= 0 else seqlen_k
        span_r = win_r if win_r >= 0 else seqlen_k
        kv_len = min(seqlen_k, span_l + seqlen_q + span_r)

    pages_per_seq = math.ceil(kv_len / block_size)

    traffic = [
        batch_size * seqlen_q * nheads * d * q_bytes,
        kv_len * batch_size * nheads_k * (d * k_bytes + value_dim * v_bytes),
        batch_size * seqlen_q * nheads * value_dim * out_bytes,
        (batch_size + 1) * 4 + batch_size * 4 + batch_size * pages_per_seq * 4,
    ]
    return traffic[0] + traffic[1] + traffic[2] + traffic[3]
import pytest
import torch
import torch.nn.functional as F
......@@ -18,19 +52,8 @@ import pdb
from einops import rearrange, repeat
from flash_attn import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
varlen_fwd_unified,
)
from flash_attn import flash_attn_func
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb
MAX_HEADDIM_SM8x = 192
......@@ -109,6 +132,7 @@ def kernel_unified_attention_2d(
TILE_SIZE: tl.constexpr, # int must be power of 2
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
VALUE_HEAD_SIZE: tl.constexpr, # int
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_ALIBI_SQRT: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
......@@ -167,6 +191,7 @@ def kernel_unified_attention_2d(
)
dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
value_dim_mask = tl.where(offs_d < VALUE_HEAD_SIZE, 1, 0).to(tl.int1)
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
......@@ -296,7 +321,7 @@ def kernel_unified_attention_2d(
# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :] & tile_mask[:, None],
mask=value_dim_mask[None, :] & tile_mask[:, None],
other=0.0,
)
......@@ -425,7 +450,7 @@ def kernel_unified_attention_2d(
tl.store(
output_ptr + output_offset,
acc,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
mask=value_dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
)
......@@ -1039,6 +1064,7 @@ def unified_attention(
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
VALUE_HEAD_SIZE=v.shape[3],
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_ALIBI_SQRT=use_alibi_sqrt,
USE_QQ_BIAS=use_qq_bias,
......@@ -1235,7 +1261,14 @@ def make_paged_kv(k_list, v_list, block_size, device, dtype):
return k_cache, v_cache, block_table
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False, is_e5m2: bool = False) -> None:
def cal_diff(
x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False,
is_e5m2: bool = False,
cos_threshold: Optional[float] = None,
) -> None:
torch_dtype = x.dtype
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
......@@ -1245,23 +1278,60 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False,
if is_e5m2:
assert cos_diff < 1e-2
elif use_fp8:
assert cos_diff < 1e-3
assert cos_diff < 5e-3
elif cos_threshold is not None:
assert cos_diff < cos_threshold
else:
assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5)
class _NoUnifiedFallback:
def __init__(self, module):
self._module = module
def __getattr__(self, name):
if name == "varlen_fwd_unified":
raise AssertionError("unexpected fallback to flash_attn_cuda.varlen_fwd_unified")
return getattr(self._module, name)
def varlen_fwd_unified_expect_hg(expected_symbol, *args, **kwargs):
    """Call ``varlen_fwd_unified`` and assert it resolved ``expected_symbol``.

    While the call runs, two globals of the module that defines
    ``varlen_fwd_unified`` are swapped out: ``flash_attn_cuda`` is wrapped in
    ``_NoUnifiedFallback`` and ``_require_hg_varlen_symbol`` is replaced by a
    recorder that notes each requested name before delegating to the
    original. Both globals are restored even when the call raises.
    """
    target_globals = varlen_fwd_unified.__globals__
    saved_module = target_globals["flash_attn_cuda"]
    saved_require = target_globals["_require_hg_varlen_symbol"]
    called_symbols = []

    def recording_require(symbol_name):
        called_symbols.append(symbol_name)
        return saved_require(symbol_name)

    target_globals["flash_attn_cuda"] = _NoUnifiedFallback(saved_module)
    target_globals["_require_hg_varlen_symbol"] = recording_require
    try:
        outcome = varlen_fwd_unified(*args, **kwargs)
    finally:
        target_globals["flash_attn_cuda"] = saved_module
        target_globals["_require_hg_varlen_symbol"] = saved_require

    assert expected_symbol in called_symbols, (
        f"expected {expected_symbol}, got HG calls {called_symbols}"
    )
    return outcome
# ---------------------------------------------------------------------------
# Accuracy tests
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("mha_type", ["gqa"])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("window_size", [(-1, -1)])
@pytest.mark.parametrize("window_size", [(-1, -1), (511, 0)])
@pytest.mark.parametrize("use_alibi_sqrt", [True, False])
@pytest.mark.parametrize("use_qq_bias", [True, False]) # seqlen_q > seqlen_k 时 skip
@pytest.mark.parametrize("use_sinks", [True, False])
@pytest.mark.parametrize("use_mm_prefix", [True, False])
@pytest.mark.parametrize("d", [128, 256])
@pytest.mark.parametrize("d,d_v", [(128, 128), (192, 128), (256, 256)])
@pytest.mark.parametrize(
"batch_size,seqlen_q,seqlen_k,block_size",
[
......@@ -1274,43 +1344,63 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False,
# --- 场景 2: Decode 场景 (增量推理) ---
# 验证 seqlen_q=1 时,如何正确从 KV Cache 的最后位置读取信息
# 此时 qq_bias 实际上只退化为向量加法,最容易测出指针偏移错误
(8, 1, 2048, 128), # 高 Batch 的标准 Decode
(1, 1, 4096, 128), # 超长上下文 Decode,验证大索引寻址
# --- 场景 3: Chunked Prefill / Speculative Decoding (分段/投机采样) ---
# Q 小于 K,但大于 1。这是最难写的逻辑,验证 Is_causal 的动态截断
(2, 128, 1024, 128), # Q 是一小段,K 是长历史
(4, 256, 512, 128), # 验证 Q 和 K 比例较近时的处理
# --- 场景 4: 边界非对称尺寸 (非 2 的幂次) ---
# 专门用来抓那些“假设数据一定是 BlockSize 整数倍”的 Bug
(64, 1, 2048, UNIFIED_BLOCK_SIZE),
(16, 1, 2048, UNIFIED_BLOCK_SIZE),
(1, 1, 4096, UNIFIED_BLOCK_SIZE),
(64, 4, 2048, UNIFIED_BLOCK_SIZE),
(32, 4, 2048, UNIFIED_BLOCK_SIZE),
(16, 4, 2048, UNIFIED_BLOCK_SIZE),
(8, 4, 2048, UNIFIED_BLOCK_SIZE), # 高 Batch 的标准 Decode
(4, 4, 2048, UNIFIED_BLOCK_SIZE),
(2, 4, 2048, UNIFIED_BLOCK_SIZE),
(1, 4, 4096, UNIFIED_BLOCK_SIZE), # 超长上下文 Decode,验证大索引寻址
# --- 场景 3: Prefix Prefill ---
(1, 16, 128, UNIFIED_BLOCK_SIZE), # fp8 prefill lower boundary
(1, 32, 512, UNIFIED_BLOCK_SIZE),
(2, 32, 513, UNIFIED_BLOCK_SIZE), # non block-aligned KV length
(2, 64, 2048, UNIFIED_BLOCK_SIZE),
(3, 96, 1537, UNIFIED_BLOCK_SIZE), # non power-of-two batch/length
(2, 128, 1024, UNIFIED_BLOCK_SIZE),
# # --- 场景 4: 边界非对称尺寸 (非 2 的幂次) ---
# # 专门用来抓那些“假设数据一定是 BlockSize 整数倍”的 Bug
(1, 127, 127, 128), # 刚好差 1 个填满 Block
(2, 33, 1025, 128), # 非常细碎的 Block 和不规则长度
],
)
def test_unified_attn_2d(
batch_size, seqlen_q, seqlen_k, block_size,
d, causal, window_size, softcap,
d, d_v, causal, window_size, softcap,
mha_type, dtype,
use_alibi_sqrt, use_qq_bias, use_sinks, use_mm_prefix,
use_alibi_sqrt, use_qq_bias, use_sinks, use_mm_prefix, use_fp8,
):
device = torch.device("cuda")
torch.manual_seed(42)
nheads = 8
nheads_k = 1 if mha_type == "gqa" else nheads
nheads = 16
nheads_k = 2 if mha_type == "gqa" else nheads
softmax_scale = d ** (-0.5)
MAX_MM_RANGES = 2
# skip invalid combos
if use_alibi_sqrt and not causal:
pytest.skip("alibi_sqrt only tested with causal=True")
if use_alibi_sqrt:
pytest.skip("HG unified attention does not support alibi_sqrt yet")
if use_qq_bias and seqlen_q > seqlen_k:
pytest.skip("qq_bias requires seqlen_q <= seqlen_k")
if use_qq_bias:
pytest.skip("HG unified attention does not support qq_bias yet")
if use_mm_prefix and seqlen_q > seqlen_k:
pytest.skip("mm_prefix not supported when seqlen_q > seqlen_k")
if use_mm_prefix:
pytest.skip("HG unified attention does not support mm_prefix yet")
if (not use_fp8) and use_sinks and (seqlen_q == 1 or 1 < seqlen_q < 16):
pytest.skip("b16 prefix decode sinks are not supported yet")
if (not use_fp8) and use_sinks and seqlen_q >= 16 and d == 256 and d_v == 256:
pytest.skip("b16 prefix prefill 256/256 sinks are not supported yet")
# if use_mm_prefix and not causal:
# pytest.skip("mm_prefix_range is only meaningful with causal=True")
......@@ -1318,15 +1408,30 @@ def test_unified_attn_2d(
for _ in range(batch_size):
q_list.append(torch.randn(seqlen_q, nheads, d, device=device, dtype=dtype))
k_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype))
v_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype))
v_list.append(torch.randn(seqlen_k, nheads_k, d_v, device=device, dtype=dtype))
if use_fp8:
fp8_dtype = current_platform.fp8_dtype()
q_kernel_list = [q.to(fp8_dtype) for q in q_list]
k_kernel_list = [k.to(fp8_dtype) for k in k_list]
v_kernel_list = [v.to(fp8_dtype) for v in v_list]
q_ref_list = [q.to(dtype) for q in q_kernel_list]
k_ref_list = [k.to(dtype) for k in k_kernel_list]
v_ref_list = [v.to(dtype) for v in v_kernel_list]
else:
q_kernel_list, k_kernel_list, v_kernel_list = q_list, k_list, v_list
q_ref_list, k_ref_list, v_ref_list = q_list, k_list, v_list
q_varlen = torch.cat(q_list, dim=0)
q_varlen = torch.cat(q_kernel_list, dim=0)
cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(
torch.tensor([seqlen_q] * batch_size, dtype=torch.int32), dim=0
)
seqused_k = torch.tensor([seqlen_k] * batch_size, device=device, dtype=torch.int32)
k_cache, v_cache, block_table = make_paged_kv(k_list, v_list, block_size, device, dtype)
k_cache, v_cache, block_table = make_paged_kv(k_kernel_list, v_kernel_list, block_size, device, q_varlen.dtype)
q_descale = torch.ones((batch_size, nheads), device=device, dtype=torch.float32) if use_fp8 else None
k_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32) if use_fp8 else None
v_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32) if use_fp8 else None
# Build optional tensors
alibi_slopes = None
......@@ -1339,7 +1444,8 @@ def test_unified_attn_2d(
sinks = None
if use_sinks:
sinks = torch.randn(nheads, device=device, dtype=dtype)
sink_dtype = torch.bfloat16 if use_fp8 else torch.float32
sinks = (torch.randn(nheads, device=device, dtype=sink_dtype) * 0.25) + 2.0
mm_prefix_range = None
if use_mm_prefix:
......@@ -1354,7 +1460,7 @@ def test_unified_attn_2d(
for i in range(batch_size):
ref_outs.append(
ref_attn(
q_list[i], k_list[i], v_list[i],
q_ref_list[i], k_ref_list[i], v_ref_list[i],
causal=causal,
window_size=window_size,
softmax_scale=softmax_scale,
......@@ -1369,7 +1475,19 @@ def test_unified_attn_2d(
ref_out = torch.cat(ref_outs, dim=0)
# ---- CUDA kernel ----
cuda_out, cuda_lse = varlen_fwd_unified(
expected_hg_symbol = None
if seqlen_q >= 16:
if not (not use_fp8 and d == 256 and d_v == 256):
expected_hg_symbol = "hg_prefix_prefill_varlen_fwd"
elif use_fp8 or seqlen_q == 1 or 1 < seqlen_q < 16:
expected_hg_symbol = "hg_prefix_decode_varlen_fwd"
varlen_runner = (
varlen_fwd_unified
if expected_hg_symbol is None
else lambda *args, **kwargs: varlen_fwd_unified_expect_hg(expected_hg_symbol, *args, **kwargs)
)
cuda_out, cuda_lse = varlen_runner(
q_varlen, k_cache, v_cache,
cu_seqlens_q, seqused_k, block_table,
max_seqlen_q=seqlen_q,
......@@ -1384,6 +1502,9 @@ def test_unified_attn_2d(
s_aux=sinks,
mm_prefix_range=mm_prefix_range,
return_softmax_lse=True,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
# # ---- Triton kernel ----
......@@ -1416,15 +1537,164 @@ def test_unified_attn_2d(
# triton_max_diff = (triton_out - ref_out).abs().max().item()
print(
f"\n[{dtype} | causal={causal} | {mha_type} | bs={batch_size} "
f"sq={seqlen_q} sk={seqlen_k} blk={block_size} | "
f"\n[{dtype} | fp8={use_fp8} | causal={causal} | {mha_type} | bs={batch_size} "
f"sq={seqlen_q} sk={seqlen_k} d={d}/{d_v} blk={block_size} | "
f"alibi_sqrt={use_alibi_sqrt} qq_bias={use_qq_bias} "
f"sinks={use_sinks} mm_prefix={use_mm_prefix}]"
f"\n CUDA max_diff={cuda_max_diff:.4e}"
# f"\n Triton max_diff={triton_max_diff:.4e}"
)
cal_diff(cuda_out, ref_out, "out")
cutlass_b16_256_prefill = (not use_fp8 and seqlen_q >= 16 and d == 256 and d_v == 256)
cal_diff(
cuda_out,
ref_out,
"out",
use_fp8=use_fp8,
cos_threshold=1e-3 if cutlass_b16_256_prefill else None,
)
@pytest.mark.parametrize(
    "batch_size,seqlen_q,seqlen_k,causal,window_size",
    [
        pytest.param(1, 256, 4096, True, (-1, -1), id="causal-large-sq"),
        pytest.param(2, 257, 4103, True, (-1, -1), id="causal-unaligned-sq-sk"),
        pytest.param(4, 512, 8193, True, (-1, -1), id="causal-large-bs-sq-unaligned-sk"),
        # The torch reference materializes dense [heads, seq_q, seq_kv] scores,
        # so keep 4K/8K correctness at bs=1 while still hitting long prefill.
        pytest.param(1, 4096, 4096, True, (-1, -1), id="causal-sq4096"),
        pytest.param(1, 8192, 8192, True, (-1, -1), id="causal-sq8192"),
        pytest.param(1, 4097, 8193, True, (-1, -1), id="causal-unaligned-sq4097-sk8193"),
        pytest.param(1, 17, 129, False, (511, 0), id="swa-lower-boundary-unaligned"),
        pytest.param(3, 65, 1025, False, (511, 0), id="swa-mid-unaligned"),
        pytest.param(2, 513, 1537, False, (511, 0), id="swa-large-unaligned-sq-sk"),
        pytest.param(1, 4096, 8193, False, (511, 0), id="swa-sq4096-sk8193"),
        pytest.param(1, 8192, 8193, False, (511, 0), id="swa-sq8192-sk8193"),
    ],
)
def test_unified_attn_fp8_192x128_prefill_corner_cases(
    batch_size, seqlen_q, seqlen_k, causal, window_size
):
    """FP8 corner cases with d=192, d_v=128 and seqlen_q >= 17.

    Delegates to ``test_unified_attn_2d`` with bf16 as the base dtype, GQA,
    no softcap and every optional feature (alibi_sqrt, qq_bias, sinks,
    mm_prefix) switched off; only the shape, causality and window vary.
    """
    # Settings that are the same for every parametrized case.
    fixed_config = dict(
        block_size=UNIFIED_BLOCK_SIZE,
        d=192,
        d_v=128,
        softcap=0.0,
        mha_type="gqa",
        dtype=torch.bfloat16,
        use_alibi_sqrt=False,
        use_qq_bias=False,
        use_sinks=False,
        use_mm_prefix=False,
        use_fp8=True,
    )
    test_unified_attn_2d(
        batch_size=batch_size,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        causal=causal,
        window_size=window_size,
        **fixed_config,
    )
@pytest.mark.parametrize(
    "batch_size,seqlen_q,seqlen_k,causal,window_size",
    [
        pytest.param(1, 17, 129, True, (-1, -1), id="causal-lower-boundary-unaligned"),
        pytest.param(2, 65, 1025, True, (-1, -1), id="causal-mid-unaligned"),
        pytest.param(1, 256, 4097, True, (-1, -1), id="causal-large-unaligned-sk"),
        pytest.param(1, 17, 129, False, (511, 0), id="swa-lower-boundary-unaligned"),
        pytest.param(2, 65, 1025, False, (511, 0), id="swa-mid-unaligned"),
        pytest.param(1, 256, 4097, False, (511, 0), id="swa-large-unaligned-sk"),
    ],
)
def test_unified_attn_fp8_256x256_prefill_corner_cases(
    batch_size, seqlen_q, seqlen_k, causal, window_size
):
    """FP8 corner cases with d=256, d_v=256 and seqlen_q >= 17.

    Delegates to ``test_unified_attn_2d`` with bf16 as the base dtype, GQA,
    no softcap and every optional feature switched off; only the shape,
    causality and window vary between cases.
    """
    # Settings that are the same for every parametrized case.
    fixed_config = dict(
        block_size=UNIFIED_BLOCK_SIZE,
        d=256,
        d_v=256,
        softcap=0.0,
        mha_type="gqa",
        dtype=torch.bfloat16,
        use_alibi_sqrt=False,
        use_qq_bias=False,
        use_sinks=False,
        use_mm_prefix=False,
        use_fp8=True,
    )
    test_unified_attn_2d(
        batch_size=batch_size,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        causal=causal,
        window_size=window_size,
        **fixed_config,
    )
@pytest.mark.parametrize(
    "batch_size,seqlen_q,seqlen_k,causal,window_size",
    [
        pytest.param(1, 1, 4096, True, (-1, -1), id="decode-sq1-large-sk"),
        pytest.param(2, 4, 2048, True, (-1, -1), id="decode-mtp4-large-sk"),
        pytest.param(1, 17, 129, True, (-1, -1), id="prefill-causal-lower-boundary"),
        pytest.param(2, 65, 1025, False, (511, 0), id="prefill-swa-unaligned"),
    ],
)
def test_unified_attn_fp8_192x128_sinks_corner_cases(
    batch_size, seqlen_q, seqlen_k, causal, window_size
):
    """FP8 corner cases with d=192, d_v=128 and sinks enabled.

    Delegates to ``test_unified_attn_2d`` with ``use_sinks=True`` and the
    other optional features off, over both short-query (seqlen_q of 1 and
    4) and longer-query (seqlen_q of 17 and 65) shapes.
    """
    # Settings that are the same for every parametrized case.
    fixed_config = dict(
        block_size=UNIFIED_BLOCK_SIZE,
        d=192,
        d_v=128,
        softcap=0.0,
        mha_type="gqa",
        dtype=torch.bfloat16,
        use_alibi_sqrt=False,
        use_qq_bias=False,
        use_sinks=True,
        use_mm_prefix=False,
        use_fp8=True,
    )
    test_unified_attn_2d(
        batch_size=batch_size,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        causal=causal,
        window_size=window_size,
        **fixed_config,
    )
@pytest.mark.parametrize("block_size", [64, 128])
@pytest.mark.parametrize("d,d_v", [(128, 128), (192, 128)])
@pytest.mark.parametrize(
    "batch_size,seqlen_q,seqlen_k,causal,window_size",
    [
        pytest.param(1, 64, 257, True, (-1, -1), id="prefill-causal-unaligned"),
        pytest.param(2, 65, 1025, False, (511, 0), id="prefill-swa-unaligned"),
        pytest.param(16, 1, 2048, True, (-1, -1), id="decode-sq1"),
        pytest.param(16, 4, 2048, False, (-1, -1), id="decode-mtp4-noncausal"),
        pytest.param(16, 4, 2048, False, (511, 0), id="decode-mtp4-swa"),
    ],
)
def test_unified_attn_b16_page64_regression(
    batch_size, seqlen_q, seqlen_k, causal, window_size, d, d_v, block_size
):
    """Non-FP8 bf16 cases run with KV block sizes of 64 and 128.

    Delegates to ``test_unified_attn_2d`` with ``use_fp8=False`` and every
    optional feature off, crossing the listed shapes with two head-dim
    pairs and two block sizes.
    """
    # Settings that are the same for every parametrized case.
    fixed_config = dict(
        softcap=0.0,
        mha_type="gqa",
        dtype=torch.bfloat16,
        use_alibi_sqrt=False,
        use_qq_bias=False,
        use_sinks=False,
        use_mm_prefix=False,
        use_fp8=False,
    )
    test_unified_attn_2d(
        batch_size=batch_size,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        block_size=block_size,
        d=d,
        d_v=d_v,
        causal=causal,
        window_size=window_size,
        **fixed_config,
    )
# ---------------------------------------------------------------------------
......@@ -1440,8 +1710,8 @@ def benchmark_unified_attention():
torch.manual_seed(42)
dtype = torch.float16
d = 256
block_size = 320
d = 128
block_size = UNIFIED_BLOCK_SIZE
warmup = 10
repeat = 50
......@@ -1452,30 +1722,35 @@ def benchmark_unified_attention():
MAX_MM_RANGES = 2
# GQA
nheads = 8
nheads_k = 1
nheads = 24
nheads_k = 2
# workload shapes
shapes = [
# (8, 2048, 2048),
# (4, 2048, 2048),
(4, 1, 2048),
(1, 4, 51200),
(2, 4, 51200),
(4, 4, 51200),
(8, 4, 51200),
(16, 4, 51200),
(32, 4, 51200),
# (4, 2048, 4096),
]
# feature configs (C A Q S P)
feature_configs = [
(1,0,0,0,0),
(1,1,0,0,0),
(1,0,1,0,0),
(1,0,0,1,0),
(1,0,0,0,1),
(1,1,1,0,0),
(1,1,0,1,0),
(1,1,0,0,1),
(1,1,1,1,0),
(1,1,1,0,1),
(1,1,1,1,1),
# (1,1,0,0,0),
# (1,0,1,0,0),
# (1,0,0,1,0),
# (1,0,0,0,1),
# (1,1,1,0,0),
# (1,1,0,1,0),
# (1,1,0,0,1),
# (1,1,1,1,0),
# (1,1,1,0,1),
# (1,1,1,1,1),
]
print("\nUnified Attention GQA Benchmark")
......@@ -1485,7 +1760,8 @@ def benchmark_unified_attention():
f"{'BS':>3} {'SQ':>6} {'SK':>6} | "
f"{'C':>1} {'A':>1} {'Q':>1} {'S':>1} {'P':>1} | "
f"{'CUDA(ms)':>10} {'Triton(ms)':>11} | "
f"{'CUDA TFLOPS':>11} {'Triton TFLOPS':>13} | {'Speedup':>8}"
f"{'CUDA TFLOPS':>11} {'Triton TFLOPS':>13} | "
f"{'CUDA(GB/s)':>10} {'Triton(GB/s)':>12} | {'Speedup':>8}"
)
print("-" * 120)
......@@ -1548,6 +1824,15 @@ def benchmark_unified_attention():
triton_out = torch.zeros_like(q_varlen)
total_bytes = estimate_unified_attention_bytes(
batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, block_size,
q_bytes=torch.finfo(dtype).bits // 8,
k_bytes=torch.finfo(dtype).bits // 8,
v_bytes=torch.finfo(dtype).bits // 8,
d_v=d,
window_size=window_size,
)
for C, A, Q, S, P in feature_configs:
causal = bool(C)
......@@ -1574,7 +1859,8 @@ def benchmark_unified_attention():
sinks = None
if use_sinks:
sinks = torch.randn(nheads, device=device, dtype=dtype)
sink_dtype = torch.bfloat16 if use_fp8 else torch.float32
sinks = torch.randn(nheads, device=device, dtype=sink_dtype)
mm_prefix_range = None
if use_mm_prefix:
......@@ -1655,6 +1941,9 @@ def benchmark_unified_attention():
# FLOPs
flops = 4.0 * batch_size * nheads * seqlen_q * seqlen_k * d
cuda_bandwidth = total_bytes / 1e9 / cuda_ms * 1000
triton_bandwidth = total_bytes / 1e9 / triton_ms * 1000
cuda_tflops = flops / cuda_ms / 1e9
triton_tflops = flops / triton_ms / 1e9
......@@ -1663,11 +1952,277 @@ def benchmark_unified_attention():
f"{C} {A} {Q} {S} {P} | "
f"{cuda_ms:10.3f} {triton_ms:11.3f} | "
f"{cuda_tflops:11.2f} {triton_tflops:13.2f} | "
f"{cuda_bandwidth:10.2f} {triton_bandwidth:12.2f} | "
f"{triton_ms/cuda_ms:8.2f}x"
)
print("=" * 120)
def benchmark_hg_b16_fp8_pa():
    """Benchmark ``varlen_fwd_unified`` in BF16 and FP8 against Triton BF16.

    For every (window, batch, seqlen_q, seqlen_k) combination this builds
    random BF16 Q/K/V, pages K/V with ``make_paged_kv``, casts copies to the
    platform FP8 dtype, and then:

    * times the Triton ``unified_attention`` kernel on the BF16 inputs,
    * runs the FP8 path once through ``varlen_fwd_unified_expect_hg`` and
      then times it,
    * runs the BF16 path once through ``varlen_fwd_unified_expect_hg``,
      optionally compares it with ``ref_attn``, and times it only if that
      comparison passed (or was skipped).

    One table row is printed per combination, followed by summary
    statistics of the measured ratios. Nothing is returned.

    Environment:
        BENCH_REF_MAX_SCORES: upper bound on
            ``batch_size * seqlen_q * seqlen_k * nheads`` for which the
            ``ref_attn`` comparison is run (default 20000000).

    NOTE(review): the indentation of this block was reconstructed from
    syntax; confirm the placement of the trailing separator and summary
    prints against the original file.
    """
    device = torch.device("cuda")
    torch.manual_seed(42)
    dtype = torch.bfloat16
    fp8_dtype = current_platform.fp8_dtype()
    d = 192
    d_v = 128
    block_size = UNIFIED_BLOCK_SIZE
    warmup = 10
    repeat = 50
    nheads = 16
    nheads_k = 2
    softcap = 0.0
    # Each shape is (batch_size, seqlen_q, seqlen_k).
    # Short queries (seqlen_q of 1 and 4) against a 51200-token KV length.
    shapes = [
        (bs, seqlen_q, 51200)
        for seqlen_q in (1, 4)
        for bs in (1, 2, 4, 8, 16, 32, 64)
    ]
    # Longer queries against a 4096-token KV length.
    shapes += [
        (bs, seqlen_q, 4096)
        for seqlen_q in (32, 128, 256, 512)
        for bs in (1, 2, 4)
    ]
    # Hand-picked large shapes, several with odd (non power-of-two) lengths.
    shapes += [
        (1, 257, 4103),
        (2, 513, 8193),
        (1, 4096, 4096),
        (1, 4097, 8193),
        (1, 8192, 8192),
        (1, 8192, 8193),
    ]
    windows = [(-1, -1), (511, 0)]

    def time_fn(fn):
        # Average wall-clock time per call in milliseconds: `warmup` untimed
        # calls, then `repeat` timed calls bracketed by device synchronizes.
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(repeat):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / repeat * 1000

    def unwrap_out(result):
        # If the kernel returned a tuple/list, keep only its first element.
        if isinstance(result, (tuple, list)):
            return result[0]
        return result

    def diff_stats(x, y):
        # Returns (cos_diff, max_abs_diff) computed in float32, where
        # cos_diff = 1 - 2*sum(x*y) / sum(x*x + y*y); the denominator is
        # clamped to avoid dividing by zero.
        x = x.float()
        y = y.float()
        denom = torch.clamp((x * x + y * y).sum(), min=1e-12)
        cos_diff = 1 - 2 * (x * y).sum() / denom
        max_diff = (x - y).abs().max()
        return cos_diff.item(), max_diff.item()

    max_ref_scores = int(os.getenv("BENCH_REF_MAX_SCORES", "20000000"))
    print("\nHG Unified PA BF16/FP8 Benchmark")
    print("=" * 170)
    print(
        f"{'BS':>3} {'SQ':>3} {'SK':>6} {'D':>7} {'WINDOW':>12} | "
        f"{'HG_OK':>5} {'HG(ms)':>8} {'TRI(ms)':>8} {'FP8(ms)':>8} | "
        f"{'FP8/HG':>8} {'FP8/TRI':>8} {'HG/TRI':>8} | "
        f"{'REF_cos':>9} {'REF_max':>9} | {'FP8 GB/s':>9} {'NOTE':>18}"
    )
    print("-" * 170)
    # Running counters and per-case ratios used for the final summary.
    summary = {
        "total": 0,
        "hg_ok": 0,
        "hg_fail": 0,
        "fp8_hg_speedups": [],
        "fp8_triton_speedups": [],
    }
    for window_size in windows:
        for batch_size, seqlen_q, seqlen_k in shapes:
            summary["total"] += 1
            # Causal masking is used only for the unbounded (-1, -1) window;
            # the (511, 0) window runs with causal=False.
            causal = window_size == (-1, -1)
            softmax_scale = d ** (-0.5)
            q_list = [torch.randn(seqlen_q, nheads, d, device=device, dtype=dtype) for _ in range(batch_size)]
            k_list = [torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype) for _ in range(batch_size)]
            v_list = [torch.randn(seqlen_k, nheads_k, d_v, device=device, dtype=dtype) for _ in range(batch_size)]
            q_b16 = torch.cat(q_list, dim=0)
            k_b16, v_b16, block_table = make_paged_kv(k_list, v_list, block_size, device, dtype)
            # FP8 inputs are direct casts of the BF16 tensors (no scaling),
            # paired with all-ones descale tensors below.
            q_fp8 = q_b16.to(fp8_dtype)
            k_fp8 = k_b16.to(fp8_dtype)
            v_fp8 = v_b16.to(fp8_dtype)
            # Cumulative query offsets: every sequence has seqlen_q tokens.
            cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
            cu_seqlens_q[1:] = torch.cumsum(
                torch.tensor([seqlen_q] * batch_size, dtype=torch.int32), dim=0
            )
            seqused_k = torch.tensor([seqlen_k] * batch_size, device=device, dtype=torch.int32)
            q_descale = torch.ones((batch_size, nheads), device=device, dtype=torch.float32)
            k_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32)
            v_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32)
            # Symbol name the checked runs must request; chosen purely by
            # whether seqlen_q reaches 16.
            expected_hg_symbol = (
                "hg_prefix_prefill_varlen_fwd"
                if seqlen_q >= 16
                else "hg_prefix_decode_varlen_fwd"
            )

            # The *_checked variants go through varlen_fwd_unified_expect_hg
            # and are called once per case; the unchecked variants call
            # varlen_fwd_unified directly and are the ones that get timed.
            def run_b16_hg_checked():
                return unwrap_out(varlen_fwd_unified_expect_hg(
                    expected_hg_symbol,
                    q_b16, k_b16, v_b16, cu_seqlens_q, seqused_k, block_table,
                    max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k,
                    softmax_scale=softmax_scale, causal=causal,
                    window_size=window_size, softcap=softcap,
                    return_softmax_lse=False,
                ))

            def run_b16_hg():
                return unwrap_out(varlen_fwd_unified(
                    q_b16, k_b16, v_b16, cu_seqlens_q, seqused_k, block_table,
                    max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k,
                    softmax_scale=softmax_scale, causal=causal,
                    window_size=window_size, softcap=softcap,
                    return_softmax_lse=False,
                ))

            def run_fp8_checked():
                return unwrap_out(varlen_fwd_unified_expect_hg(
                    expected_hg_symbol,
                    q_fp8, k_fp8, v_fp8, cu_seqlens_q, seqused_k, block_table,
                    max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k,
                    softmax_scale=softmax_scale, causal=causal,
                    window_size=window_size, softcap=softcap,
                    return_softmax_lse=False,
                    q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
                ))

            def run_fp8():
                return unwrap_out(varlen_fwd_unified(
                    q_fp8, k_fp8, v_fp8, cu_seqlens_q, seqused_k, block_table,
                    max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k,
                    softmax_scale=softmax_scale, causal=causal,
                    window_size=window_size, softcap=softcap,
                    return_softmax_lse=False,
                    q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
                ))

            # Output buffer for the Triton kernel; reused across timed calls.
            triton_out = torch.empty(
                (q_b16.shape[0], nheads, d_v), device=device, dtype=dtype
            )

            def run_triton_b16():
                unified_attention(
                    q=q_b16,
                    k=k_b16,
                    v=v_b16,
                    out=triton_out,
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=seqlen_k,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=window_size,
                    block_table=block_table,
                    softcap=softcap,
                    q_descale=None,
                    k_descale=None,
                    v_descale=None,
                    seq_threshold_3D=128,
                )
                return triton_out

            # NaN placeholders keep the row printable when the BF16 run
            # fails or is not timed.
            note = ""
            hg_ms = float("nan")
            fp8_over_hg = float("nan")
            hg_over_triton = float("nan")
            hg_cos = float("nan")
            hg_max = float("nan")
            triton_ms = time_fn(run_triton_b16)
            # The FP8 path is not wrapped in try/except: an exception here
            # aborts the whole benchmark.
            run_fp8_checked()
            torch.cuda.synchronize()
            fp8_ms = time_fn(run_fp8)
            # Any exception from the BF16 run, the reference comparison or
            # its timing is counted as a failure; the row is still printed.
            try:
                b16_hg_out = run_b16_hg_checked()
                torch.cuda.synchronize()
                num_ref_scores = batch_size * seqlen_q * seqlen_k * nheads
                hg_correct = True
                # Only run the ref_attn comparison when the score count is
                # within the BENCH_REF_MAX_SCORES budget.
                if num_ref_scores <= max_ref_scores:
                    ref_out = torch.cat([
                        ref_attn(
                            q_list[i], k_list[i], v_list[i],
                            causal=causal,
                            window_size=window_size,
                            softmax_scale=softmax_scale,
                            softcap=softcap,
                        )
                        for i in range(batch_size)
                    ], dim=0)
                    hg_cos, hg_max = diff_stats(b16_hg_out, ref_out)
                    hg_correct = hg_cos < 1e-3
                    if not hg_correct:
                        note = "HG_BF16_REF_DIFF"
                else:
                    # Reference skipped: the BF16 run is treated as correct.
                    note = "REF_SKIP"
                if not hg_correct:
                    summary["hg_fail"] += 1
                else:
                    hg_ms = time_fn(run_b16_hg)
                    # Ratios are baseline_time / candidate_time, so values
                    # above 1 mean the candidate (FP8, resp. HG) was faster.
                    fp8_over_hg = hg_ms / fp8_ms
                    hg_over_triton = triton_ms / hg_ms
                    summary["hg_ok"] += 1
                    summary["fp8_hg_speedups"].append(fp8_over_hg)
            except Exception as exc:
                summary["hg_fail"] += 1
                note = type(exc).__name__
            fp8_over_triton = triton_ms / fp8_ms
            summary["fp8_triton_speedups"].append(fp8_over_triton)
            # Byte estimate for the FP8 case: 1-byte Q/K/V, 2-byte output.
            fp8_bytes = estimate_unified_attention_bytes(
                batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, block_size,
                q_bytes=1, k_bytes=1, v_bytes=1, out_bytes=2, d_v=d_v,
                window_size=window_size,
            )
            # HG_OK shows whether hg_ms is finite, i.e. whether the BF16
            # run was actually timed. Bandwidth is bytes / 1e9 / ms * 1000.
            print(
                f"{batch_size:3d} {seqlen_q:3d} {seqlen_k:6d} {str(d) + '/' + str(d_v):>7} {str(window_size):>12} | "
                f"{str(math.isfinite(hg_ms)):>5} {hg_ms:8.3f} {triton_ms:8.3f} {fp8_ms:8.3f} | "
                f"{fp8_over_hg:8.2f} {fp8_over_triton:8.2f} {hg_over_triton:8.2f} | "
                f"{hg_cos:9.2e} {hg_max:9.2e} | "
                f"{fp8_bytes / 1e9 / fp8_ms * 1000:9.2f} {note:>18}"
            )
    # NOTE(review): placed after both loops as a table footer — confirm.
    print("-" * 170)
    # The BF16 summary is only printed if at least one BF16 case was timed.
    if summary["fp8_hg_speedups"]:
        hg_speedups = torch.tensor(summary["fp8_hg_speedups"], dtype=torch.float32)
        print(
            "HG_BF16 baseline summary: "
            f"ok={summary['hg_ok']}/{summary['total']} "
            f"fail={summary['hg_fail']} "
            f"fp8/hg mean={hg_speedups.mean().item():.3f} "
            f"median={hg_speedups.median().item():.3f} "
            f"min={hg_speedups.min().item():.3f} "
            f"max={hg_speedups.max().item():.3f}"
        )
    # NOTE(review): assumed to sit outside the `if` above, since the Triton
    # ratio list gets an entry for every case — confirm.
    triton_speedups = torch.tensor(summary["fp8_triton_speedups"], dtype=torch.float32)
    print(
        "Triton BF16 reference summary: "
        f"fp8/triton mean={triton_speedups.mean().item():.3f} "
        f"median={triton_speedups.median().item():.3f} "
        f"min={triton_speedups.min().item():.3f} "
        f"max={triton_speedups.max().item():.3f}"
    )
    print("=" * 170)
if __name__ == "__main__":
benchmark_unified_attention()
\ No newline at end of file
if os.getenv("RUN_UNIFIED_BENCHMARK") == "1":
benchmark_hg_b16_fp8_pa()
else:
test_unified_attn_2d(
batch_size=1, seqlen_q=1, seqlen_k=40960, block_size=UNIFIED_BLOCK_SIZE,
d=192, d_v=128, causal=False, window_size=(511, 0), softcap=0.0,
mha_type="gqa", dtype=torch.bfloat16,
use_alibi_sqrt=False, use_qq_bias=False, use_sinks=False, use_mm_prefix=False,
use_fp8=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment