Commit 5248d7d2 authored by hly's avatar hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
#ifdef BUILD_FA_PERMUTE #ifdef BUILD_FA_PERMUTE
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include "../../include/intrinsic.h"
#include "../flash_fwd_permute_hdim128.h" #include "../flash_fwd_permute_hdim128.h"
template<> template<>
...@@ -141,27 +140,21 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<128, 4, 32>( ...@@ -141,27 +140,21 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<128, 4, 32>(
int32_t thread_offset = lane_id_col * 8; int32_t thread_offset = lane_id_col * 8;
// 一次读取 4x128 的 Half 到 LDS // 一次读取 4x128 的 Half 到 LDS
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
{ __builtin_hcu_raw_buffer_load_lds(
auto *lds_ptr = (__attribute__((address_space(3))) int *)( read_buffer,
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float)); lds + lane_id * 4,
__builtin_hcu_raw_buffer_load_lds( 16,
read_buffer, (block_offset + thread_offset) << 1, /* v_offset */
lds_ptr, 0, /* s_offset */
16, 0, /* immediate offset, instruction offset */
(block_offset + thread_offset) << 1, /* v_offset */ 0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
0, /* s_offset */ );
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else #else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1)); *(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif #endif
// 从 LDS 转置后, 64 个线程写 4 行, 每次写 128 个 Half, 对应 fetch * 4 + [0,3] 的 seqlen // 从 LDS 转置后, 64 个线程写 4 行, 每次写 128 个 Half, 对应 fetch * 4 + [0,3] 的 seqlen
vec2_fp32 data0, data1; vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64); vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 0) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[0]; write_ptr[(min(actual_seqlen - 1, fetch * 4 + 0) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[0];
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 1) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[1]; write_ptr[(min(actual_seqlen - 1, fetch * 4 + 1) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[1];
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 2) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data1[0]; write_ptr[(min(actual_seqlen - 1, fetch * 4 + 2) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data1[0];
...@@ -370,4 +363,4 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<256, 1, 32>( ...@@ -370,4 +363,4 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<256, 1, 32>(
#endif #endif
\ No newline at end of file
#ifdef BUILD_FA_PERMUTE #ifdef BUILD_FA_PERMUTE
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include "../../include/intrinsic.h"
#include "../flash_fwd_permute_hdim128.h" #include "../flash_fwd_permute_hdim128.h"
...@@ -119,28 +118,22 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 0>( ...@@ -119,28 +118,22 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 0>(
int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8; int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下 // block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
{ __builtin_hcu_raw_buffer_load_lds(
auto *lds_ptr = (__attribute__((address_space(3))) int *)( read_buffer,
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float)); lds + lane_id * 4,
__builtin_hcu_raw_buffer_load_lds( 16,
read_buffer, (block_offset + thread_offset) << 1, /* v_offset */
lds_ptr, 0, /* s_offset */
16, 0, /* immediate offset, instruction offset */
(block_offset + thread_offset) << 1, /* v_offset */ 0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
0, /* s_offset */ );
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else #else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1)); *(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif #endif
// 写到 lds 不需要同步, 因为只有一个 wave // 写到 lds 不需要同步, 因为只有一个 wave
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次 // 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
vec2_fp32 data0, data1; vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64); vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
write_ptr[(fetch * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0]; write_ptr[(fetch * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0];
write_ptr[(fetch * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1]; write_ptr[(fetch * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1];
write_ptr[(fetch * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0]; write_ptr[(fetch * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0];
...@@ -268,28 +261,22 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 32>( ...@@ -268,28 +261,22 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 32>(
int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8; int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下 // block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
{ __builtin_hcu_raw_buffer_load_lds(
auto *lds_ptr = (__attribute__((address_space(3))) int *)( read_buffer,
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float)); lds + lane_id * 4,
__builtin_hcu_raw_buffer_load_lds( 16,
read_buffer, (block_offset + thread_offset) << 1, /* v_offset */
lds_ptr, 0, /* s_offset */
16, 0, /* immediate offset, instruction offset */
(block_offset + thread_offset) << 1, /* v_offset */ 0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
0, /* s_offset */ );
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else #else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1)); *(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif #endif
// 写到 lds 不需要同步, 因为只有一个 wave // 写到 lds 不需要同步, 因为只有一个 wave
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次 // 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
vec2_fp32 data0, data1; vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64); vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0]; write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0];
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1]; write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1];
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0]; write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0];
...@@ -383,9 +370,8 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128, ...@@ -383,9 +370,8 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
#endif #endif
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次 // 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
inlineasm_fa_ds_read2_b32(lds, fetch * 256 + lane_id, registers_buffer[fetch * 2], 0, 64); registers_buffer[fetch * 2] = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + fetch * 256 + lane_id, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, fetch * 256 + lane_id + 128, registers_buffer[fetch * 2 + 1], 0, 64); registers_buffer[fetch * 2 + 1] = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + fetch * 256 + lane_id + 128, 0, 64, false);
asm volatile("s_waitcnt lgkmcnt(0)\n");
} }
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -688,4 +674,4 @@ template<> ...@@ -688,4 +674,4 @@ template<>
__global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<256, 4, 32>( __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<256, 4, 32>(
void* output, void* query, void* split_sizes, int64_t head_stride, int32_t num_heads, int real_headdim) {} void* output, void* query, void* split_sizes, int64_t head_stride, int32_t num_heads, int real_headdim) {}
#endif // end of BUILD_FA_PERMUTE #endif // end of BUILD_FA_PERMUTE
\ No newline at end of file
...@@ -26,7 +26,6 @@ if torch.cuda.is_available(): ...@@ -26,7 +26,6 @@ if torch.cuda.is_available():
flash_attn_varlen_with_mask_func, flash_attn_varlen_with_mask_func,
# unified attn functions # unified attn functions
varlen_fwd_unified, varlen_fwd_unified,
fwd_sparse_mean_pool_fast,
) )
# triton fa interface # triton fa interface
from flash_attn.flash_attn_triton_interface import flash_attn_func as triton_flash_attn_func from flash_attn.flash_attn_triton_interface import flash_attn_func as triton_flash_attn_func
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
from typing import Optional, Union from typing import Optional, Union
from typing import List, Tuple from typing import List, Tuple
import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -19,6 +18,12 @@ from flash_attn.utils.sparse_utils import hyperparameter_check, get_block_map_me ...@@ -19,6 +18,12 @@ from flash_attn.utils.sparse_utils import hyperparameter_check, get_block_map_me
DEFAULT_FA_VERSION = 2 DEFAULT_FA_VERSION = 2
try:
torch._C._dispatch_find_schema_or_throw("flash_attn2_c_op::varlen_fwd", "")
_has_flash_attn2_c_varlen_fwd = True
except RuntimeError:
_has_flash_attn2_c_varlen_fwd = False
def maybe_contiguous(x): def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x return x.contiguous() if x is not None and x.stride(-1) != 1 else x
...@@ -59,7 +64,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal): ...@@ -59,7 +64,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
elif head_dim <= 512: elif head_dim <= 512:
return 64 return 64
if torch.__version__ >= "2.4.0": if torch.__version__ >= "2.4.0" and _has_flash_attn2_c_varlen_fwd:
_torch_custom_op_wrapper = torch.library.custom_op _torch_custom_op_wrapper = torch.library.custom_op
_torch_register_fake_wrapper = torch.library.register_fake _torch_register_fake_wrapper = torch.library.register_fake
else: else:
...@@ -199,7 +204,11 @@ def varlen_fwd_fake( ...@@ -199,7 +204,11 @@ def varlen_fwd_fake(
return out, softmax_lse, p, rng_state return out, softmax_lse, p, rng_state
_wrapped_flash_attn_varlen_forward = torch.ops.flash_attn2_c_op.varlen_fwd _wrapped_flash_attn_varlen_forward = (
torch.ops.flash_attn2_c_op.varlen_fwd
if _has_flash_attn2_c_varlen_fwd
else _flash_attn_varlen_forward
)
def _flash_attn_backward( def _flash_attn_backward(
...@@ -2008,7 +2017,7 @@ def vllm_flash_attn_varlen_func( ...@@ -2008,7 +2017,7 @@ def vllm_flash_attn_varlen_func(
# if mtp, k head must be 1. # if mtp, k head must be 1.
# todo : support k head >1 # todo : support k head >1
is_mtp = (max_seqlen_q*bs==total_q and max_seqlen_q>1 and max_seqlen_q<5) is_mtp = (max_seqlen_q*bs==total_q and max_seqlen_q>1 and max_seqlen_q<5)
if (max_seqlen_q==1 or is_mtp ) and real_window_size[0]==-1: if (max_seqlen_q==1 or (is_mtp and k.shape[1]==1)) and real_window_size[0]==-1:
if out==None: if out==None:
if q.dtype == torch.float8_e4m3fn or q.dtype == torch.float8_e5m2: if q.dtype == torch.float8_e4m3fn or q.dtype == torch.float8_e5m2:
out = torch.empty(q.size(),device = q.device,dtype=torch.bfloat16) out = torch.empty(q.size(),device = q.device,dtype=torch.bfloat16)
...@@ -2732,6 +2741,9 @@ def varlen_fwd_unified( ...@@ -2732,6 +2741,9 @@ def varlen_fwd_unified(
*, *,
out=None, out=None,
return_softmax_lse=False, return_softmax_lse=False,
q_descale=None,
k_descale=None,
v_descale=None,
): ):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
...@@ -2741,6 +2753,78 @@ def varlen_fwd_unified( ...@@ -2741,6 +2753,78 @@ def varlen_fwd_unified(
window_size_left, window_size_right = window_size window_size_left, window_size_right = window_size
if k.dtype.is_floating_point:
k_dtype_bits = torch.finfo(k.dtype).bits
else:
k_dtype_bits = torch.iinfo(k.dtype).bits
is_mtp = (max_seqlen_q * seqused_k.size(0) == q.shape[0] and 1 < max_seqlen_q < 16)
if max_seqlen_q >= 16:
bshd_prefill = _require_hg_varlen_symbol("hg_prefix_prefill_varlen_fwd")
fa_output, *rest_extend = bshd_prefill(
q,
k,
v,
out, # out_
cu_seqlens_q,
None, # cu_seqlens_k
seqused_k,
alibi_slopes,
block_table,
max_seqlen_q,
max_seqlen_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
causal,
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
1,
None if (k_dtype_bits == 16) else q_descale,
None if (k_dtype_bits == 16) else k_descale,
None if (k_dtype_bits == 16) else v_descale,
True,
)
return (fa_output, rest_extend[0]) if return_softmax_lse else fa_output
bs = seqused_k.size(0)
total_q = q.shape[0]
if max_seqlen_q == 1 or is_mtp:
assert not use_alibi_sqrt and qq_bias is None and s_aux is None and mm_prefix_range is None, \
f"Arguments not supported in hg_bshd_decode"
bshd_pa_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")
result = bshd_pa_decode(
q,
k,
v,
out,
cu_seqlens_q,
None,
seqused_k,
alibi_slopes,
block_table,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale,
False,
causal,
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
1,
None if (k_dtype_bits == 16) else q_descale,
None if (k_dtype_bits == 16) else k_descale,
None if (k_dtype_bits == 16) else v_descale,
True,
)
fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output
out, softmax_lse = flash_attn_cuda.varlen_fwd_unified( out, softmax_lse = flash_attn_cuda.varlen_fwd_unified(
q, q,
k, k,
...@@ -3816,22 +3900,6 @@ def spas_fa2_attn_meansim_topk_varlen_cuda( ...@@ -3816,22 +3900,6 @@ def spas_fa2_attn_meansim_topk_varlen_cuda(
) )
def fwd_sparse_mean_pool_fast(x,BLK,mean=None):
    """Forward to the compiled ``flash_attn_cuda.fwd_sparse_mean_pool_fast`` kernel.

    Args:
        x: input tensor handed to the extension unchanged.
        BLK: block size handed to the extension unchanged.
        mean: optional tensor handed to the extension unchanged.
            NOTE(review): presumably a precomputed mean that the kernel
            subtracts before pooling -- confirm against the extension source.

    Returns:
        Whatever the extension returns. NOTE(review): by its name this should
        be one mean-pooled vector per block of ``BLK`` positions -- confirm.
    """
    return flash_attn_cuda.fwd_sparse_mean_pool_fast(x,BLK,mean)
def get_block_map_fast(q, k, topk_ratio, BLKQ=128, BLKK=64):
    """Build a block-sparse attention map from block-pooled q/k scores.

    Args:
        q: query tensor, pooled in blocks of ``BLKQ``.
        k: key tensor, pooled in blocks of ``BLKK`` together with its mean
            over dim ``-3``.
            NOTE(review): the caller in this file passes (B, L, H, D) tensors,
            which makes dim ``-3`` the sequence axis -- confirm for others.
        topk_ratio: fraction of key blocks to keep for every query block.
        BLKQ: query block size.
        BLKK: key block size.

    Returns:
        ``(sparse_map, lut, topk)``: an int8 mask with the pooled-score shape
        holding 1 at the selected key blocks, the selected key-block indices,
        and the number of key blocks kept per query block.
    """
    key_mean = torch.mean(k, dim=-3, keepdim=True)
    # Keys are pooled before queries, matching the original call order.
    key_blocks = fwd_sparse_mean_pool_fast(k, BLKK, key_mean)
    query_blocks = fwd_sparse_mean_pool_fast(q, BLKQ)

    block_scores = query_blocks @ key_blocks.transpose(-1, -2)
    num_key_blocks = block_scores.shape[-1]
    topk = min(num_key_blocks, int(topk_ratio * num_key_blocks))

    lut = torch.topk(block_scores, topk, dim=-1, sorted=False).indices
    sparse_map = torch.zeros_like(block_scores, dtype=torch.int8)
    sparse_map.scatter_(-1, lut, 1)
    return sparse_map, lut, topk
class SparseLinearAttention(nn.Module): class SparseLinearAttention(nn.Module):
def __init__(self, head_dim, topk, feature_map='softmax', use_bf16=True, use_fp8=False, tie_feature_map_qk=True): def __init__(self, head_dim, topk, feature_map='softmax', use_bf16=True, use_fp8=False, tie_feature_map_qk=True):
R''' R'''
...@@ -3888,15 +3956,19 @@ class SparseLinearAttention(nn.Module): ...@@ -3888,15 +3956,19 @@ class SparseLinearAttention(nn.Module):
''' '''
B, seqlen_q, H, headdim = q.shape B, seqlen_q, H, headdim = q.shape
q_bhld = q.transpose(1, 2).contiguous() # (B, H, L, D)
k_bhld = k.transpose(1, 2).contiguous()
# v_bhld = v.transpose(1, 2).contiguous()
# import pdb
# pdb.set_trace()
if headdim == 64: if headdim == 64:
block_m = 64 if seqlen_q <= 2048 else 128 block_m = 64 if seqlen_q <= 2048 else 128
elif headdim == 128: elif headdim == 128:
block_m = 64 if seqlen_q <= 2048 else 128 block_m = 64 if seqlen_q <= 2048 else 128
block_k = 64 block_k = 64
if headdim == 64: sparse_map, lut, real_topk = get_block_map(q_bhld, k_bhld, topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
else:
sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
q = q.to(self.dtype) q = q.to(self.dtype)
k = k.to(self.dtype) k = k.to(self.dtype)
...@@ -3912,10 +3984,10 @@ class SparseLinearAttention(nn.Module): ...@@ -3912,10 +3984,10 @@ class SparseLinearAttention(nn.Module):
seqlen_k = k.size(1) seqlen_k = k.size(1)
num_blocks_q = (seqlen_q + block_m - 1) // block_m num_blocks_q = (seqlen_q + block_m - 1) // block_m
num_blocks_k = (seqlen_k + block_k - 1) // block_k num_blocks_k = (seqlen_k + block_k - 1) // block_k
column_count = torch.empty( column_count = torch.zeros(
(B, H, num_blocks_q), dtype=torch.int32, device=q.device (B, H, num_blocks_q), dtype=torch.int32, device=q.device
) )
column_index = torch.empty( column_index = torch.zeros(
(B, H, num_blocks_q, 1), dtype=torch.int32, device=q.device (B, H, num_blocks_q, 1), dtype=torch.int32, device=q.device
) )
...@@ -3984,56 +4056,14 @@ def sparse_attn_with_sla( ...@@ -3984,56 +4056,14 @@ def sparse_attn_with_sla(
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
dtype = torch.bfloat16 if use_bf16 else torch.float16 attn = SparseLinearAttention(
dtype = torch.float8_e4m3fn if use_fp8 else dtype head_dim=q.size(-1),
B, seqlen_q, H, headdim = q.shape topk=topk, # = 1 - sparsity
assert not (use_bf16 and use_fp8), "Only one of bf16 and fp8 can be used." feature_map=feature_map, # options: elu, relu, softmax
assert headdim in (64, 128), "Dtype fp16/bf16 only support dim (64, 128)." use_bf16=use_bf16,
assert not (use_fp8 and headdim==64), "Dtype fp8 only support dim 128." use_fp8=use_fp8,
if headdim == 64: ).cuda()
block_m = 64 if seqlen_q <= 2048 else 128 return attn(q, k, v, return_sparsity=return_sparsity)
elif headdim == 128:
block_m = 64 if seqlen_q <= 2048 else 128
block_k = 64
if headdim == 64:
sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=topk, BLKQ=block_m, BLKK=block_k)
else:
sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=topk, BLKQ=block_m, BLKK=block_k)
q = q.to(dtype)
k = k.to(dtype)
v = v.to(dtype)
########## SPARGE BEGIN ##########
headdim = q.size(-1)
block_offset, block_count = block_map_to_block_offset_triton(sparse_map)
block_offset = block_offset * block_k
softmax_scale = 1.0 / (headdim ** 0.5)
assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
seqlen_k = k.size(1)
num_blocks_q = (seqlen_q + block_m - 1) // block_m
num_blocks_k = (seqlen_k + block_k - 1) // block_k
column_count = torch.empty(
(B, H, num_blocks_q), dtype=torch.int32, device=q.device
)
column_index = torch.empty(
(B, H, num_blocks_q, 1), dtype=torch.int32, device=q.device
)
o_s = sparse_attn_func(
q, k, v, # Use original BLHD layout
block_count=block_count,
block_offset=block_offset,
column_count=column_count,
column_index=column_index,
softmax_scale=softmax_scale,
is_sla=True,
)
if return_sparsity:
return o_s, real_topk / sparse_map.shape[-1]
else:
return o_s
def _require_hg_varlen_symbol(name: str): def _require_hg_varlen_symbol(name: str):
...@@ -4045,15 +4075,6 @@ def _require_hg_varlen_symbol(name: str): ...@@ -4045,15 +4075,6 @@ def _require_hg_varlen_symbol(name: str):
return symbol return symbol
def _apply_hg_kvcache_safe_env() -> None:
# DTK/gfx938 PA launch is sensitive to these knobs. Keep the old HG-safe
# defaults unless the caller explicitly asks for the raw kernel selection.
if os.environ.get("HG_KVCACHE_RAW_KERNEL") == "1":
return
os.environ.setdefault("PA_NO_MLS", "1")
os.environ.setdefault("PA_USE_TILE32X32", "1")
def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None: def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None:
if k_cache.dim() != 4 or v_cache.dim() != 4: if k_cache.dim() != 4 or v_cache.dim() != 4:
raise ValueError("HG paged KV cache expects k and v to both be 4D tensors") raise ValueError("HG paged KV cache expects k and v to both be 4D tensors")
...@@ -4067,53 +4088,6 @@ def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None: ...@@ -4067,53 +4088,6 @@ def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None:
) )
def _normalize_hg_paged_q_scales(q_scale, batch_size, num_heads_q, num_heads_k):
if q_scale is None:
raise ValueError("q_descale must be provided for HG int8 paged-kvcache path")
q_scale = maybe_contiguous(q_scale)
if q_scale.dim() == 1:
if q_scale.numel() == batch_size * num_heads_q:
q_scale = q_scale.view(batch_size, num_heads_q)
elif q_scale.numel() == batch_size * num_heads_k:
q_scale = q_scale.view(batch_size, num_heads_k)
if q_scale.dim() != 2 or q_scale.shape[0] != batch_size:
raise ValueError(
"q_descale must have shape [batch_size, num_heads_q] "
"or [batch_size, num_heads_k] for HG int8 paged-kvcache path"
)
if q_scale.shape[1] == num_heads_q:
return q_scale.contiguous()
if q_scale.shape[1] == num_heads_k and num_heads_q % num_heads_k == 0:
return q_scale.repeat_interleave(num_heads_q // num_heads_k, dim=1).contiguous()
raise ValueError(
"q_descale must have shape [batch_size, num_heads_q] "
"or [batch_size, num_heads_k] for HG int8 paged-kvcache path"
)
def _expand_hg_paged_kv_scales(scale, block_table, page_block_size, num_heads_k, name):
if scale is None:
raise ValueError(f"{name} must be provided for HG int8 paged-kvcache path")
scale = maybe_contiguous(scale)
batch_size = block_table.shape[0]
if scale.dim() == 1 and scale.numel() == batch_size * num_heads_k:
scale = scale.view(batch_size, num_heads_k)
if scale.dim() != 2 or scale.shape != (batch_size, num_heads_k):
raise ValueError(
f"{name} must have shape [batch_size, num_heads_k] for HG int8 paged-kvcache path"
)
expanded = torch.empty(
(int(block_table.max().item()) + 1, page_block_size, num_heads_k),
device=scale.device,
dtype=scale.dtype,
)
for batch_idx in range(batch_size):
block_ids = block_table[batch_idx].to(dtype=torch.long)
expanded[block_ids] = scale[batch_idx].view(1, 1, num_heads_k).expand(
block_ids.numel(), page_block_size, num_heads_k
)
return expanded.contiguous()
def hg_flash_attn_varlen_func( def hg_flash_attn_varlen_func(
q, q,
k, k,
...@@ -4333,10 +4307,6 @@ def hg_flash_attn_varlen_func( ...@@ -4333,10 +4307,6 @@ def hg_flash_attn_varlen_func(
raise ValueError("cu_seqlens_k and seqused_k cannot be provided at the same time") raise ValueError("cu_seqlens_k and seqused_k cannot be provided at the same time")
if block_table is None: if block_table is None:
raise ValueError("block_table must be provided when seqused_k is used") raise ValueError("block_table must be provided when seqused_k is used")
if return_attn_probs:
raise NotImplementedError(
"return_attn_probs is not supported for HG prefix/paged compatibility paths"
)
if dropout_p != 0.0: if dropout_p != 0.0:
raise NotImplementedError("dropout_p must be 0.0 for HG prefix/paged compatibility paths") raise NotImplementedError("dropout_p must be 0.0 for HG prefix/paged compatibility paths")
...@@ -4366,7 +4336,7 @@ def hg_flash_attn_varlen_func( ...@@ -4366,7 +4336,7 @@ def hg_flash_attn_varlen_func(
window_size[0], window_size[0],
window_size[1], window_size[1],
softcap, softcap,
return_softmax_lse, return_softmax_lse or return_attn_probs,
1, 1,
None if k_dtype_bits == 16 else q_descale, None if k_dtype_bits == 16 else q_descale,
None if k_dtype_bits == 16 else k_descale, None if k_dtype_bits == 16 else k_descale,
...@@ -4374,7 +4344,7 @@ def hg_flash_attn_varlen_func( ...@@ -4374,7 +4344,7 @@ def hg_flash_attn_varlen_func(
is_bf16_output, is_bf16_output,
) )
fa_output = result[0] fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output return (fa_output, result[1]) if (return_softmax_lse or return_attn_probs) else fa_output
if k_dtype_bits == 16: if k_dtype_bits == 16:
prefix_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd") prefix_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")
...@@ -4397,13 +4367,17 @@ def hg_flash_attn_varlen_func( ...@@ -4397,13 +4367,17 @@ def hg_flash_attn_varlen_func(
window_size[0], window_size[0],
window_size[1], window_size[1],
softcap, softcap,
return_softmax_lse, return_softmax_lse or return_attn_probs,
1, 1,
None,
None,
None,
is_bf16_output,
) )
fa_output = result[0] fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output return (fa_output, result[1]) if (return_softmax_lse or return_attn_probs) else fa_output
if return_softmax_lse: if return_softmax_lse or return_attn_probs:
raise NotImplementedError( raise NotImplementedError(
"return_softmax_lse is not supported for the HG paged-kvcache compatibility path" "return_softmax_lse is not supported for the HG paged-kvcache compatibility path"
) )
...@@ -4411,28 +4385,6 @@ def hg_flash_attn_varlen_func( ...@@ -4411,28 +4385,6 @@ def hg_flash_attn_varlen_func(
_validate_hg_paged_kv_contract(k, v) _validate_hg_paged_kv_contract(k, v)
if k.shape[1] != 128: if k.shape[1] != 128:
raise NotImplementedError("HG paged-kvcache path currently requires page_block_size == 128") raise NotImplementedError("HG paged-kvcache path currently requires page_block_size == 128")
_apply_hg_kvcache_safe_env()
q_descale = _normalize_hg_paged_q_scales(
q_descale,
batch_size=block_table.shape[0],
num_heads_q=q.shape[1],
num_heads_k=k.shape[2],
)
k_descale = _expand_hg_paged_kv_scales(
k_descale,
block_table=block_table,
page_block_size=k.shape[1],
num_heads_k=k.shape[2],
name="k_descale",
)
v_descale = _expand_hg_paged_kv_scales(
v_descale,
block_table=block_table,
page_block_size=v.shape[1],
num_heads_k=v.shape[2],
name="v_descale",
)
hg_kvcache = _require_hg_varlen_symbol("hg_fwd_kvcache_bshd") hg_kvcache = _require_hg_varlen_symbol("hg_fwd_kvcache_bshd")
result = hg_kvcache( result = hg_kvcache(
q.unsqueeze(1), q.unsqueeze(1),
...@@ -4442,7 +4394,7 @@ def hg_flash_attn_varlen_func( ...@@ -4442,7 +4394,7 @@ def hg_flash_attn_varlen_func(
None, None,
None, None,
seqused_k, seqused_k,
max_seqlen_k if max_seqlen_k > 0 else int(seqused_k.max().item()), 1,
None, None,
None, None,
None, None,
...@@ -4456,12 +4408,12 @@ def hg_flash_attn_varlen_func( ...@@ -4456,12 +4408,12 @@ def hg_flash_attn_varlen_func(
window_size[1], window_size[1],
softcap, softcap,
False, False,
num_splits, -1,
None, None,
None, None,
q_descale, None if k_dtype_bits == 16 else q_descale,
k_descale, None if k_dtype_bits == 16 else k_descale,
v_descale, None if k_dtype_bits == 16 else v_descale,
is_bf16_output, is_bf16_output,
) )
return result[0].squeeze(1) return result[0].squeeze(1)
...@@ -112,6 +112,8 @@ _HG_EXPLICIT_SOURCES_BY_MODE = { ...@@ -112,6 +112,8 @@ _HG_EXPLICIT_SOURCES_BY_MODE = {
"src/target/flash_fwd_hdim128_padding_mask_fp16.cpp", "src/target/flash_fwd_hdim128_padding_mask_fp16.cpp",
"src/target/flash_fwd_hdim128_prefix_prefill_bf16.cpp", "src/target/flash_fwd_hdim128_prefix_prefill_bf16.cpp",
"src/target/flash_fwd_hdim128_prefix_prefill_fp16.cpp", "src/target/flash_fwd_hdim128_prefix_prefill_fp16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp",
"src/target/flash_fwd_hdim160_bf16.cpp", "src/target/flash_fwd_hdim160_bf16.cpp",
"src/target/flash_fwd_hdim160_fp16.cpp", "src/target/flash_fwd_hdim160_fp16.cpp",
"src/target/flash_fwd_hdim192_bf16.cpp", "src/target/flash_fwd_hdim192_bf16.cpp",
...@@ -262,6 +264,32 @@ def _ninja_shell_join(args) -> str: ...@@ -262,6 +264,32 @@ def _ninja_shell_join(args) -> str:
return " ".join(_ninja_escape(shlex.quote(str(x))) for x in args) return " ".join(_ninja_escape(shlex.quote(str(x))) for x in args)
def _resolve_hg_compiler() -> str:
candidates = [
os.environ.get("FLASH_ATTN_HG_COMPILER"),
"/opt/dtk/bin/aicc",
"aicc",
]
for compiler in candidates:
if compiler and shutil.which(compiler):
return compiler
requested = [c for c in candidates if c]
raise RuntimeError(
"error: no usable HG aicc compiler found from: "
+ ", ".join(repr(c) for c in requested)
+ ". Set FLASH_ATTN_HG_COMPILER to the DTK aicc path."
)
def _is_rocm_5_7() -> bool:
version_file = "/opt/rocm/.info/version"
try:
with open(version_file, "r", encoding="utf-8") as f:
return f.read(100).startswith("5.7.0")
except OSError:
return False
def compute_hg_build_descriptor( def compute_hg_build_descriptor(
src_dir, src_dir,
build_dir, build_dir,
...@@ -279,7 +307,7 @@ def compute_hg_build_descriptor( ...@@ -279,7 +307,7 @@ def compute_hg_build_descriptor(
BUILD_FA_FWD = BUILD_FA_BWD = BUILD_FA_KVCACHE = False BUILD_FA_FWD = BUILD_FA_BWD = BUILD_FA_KVCACHE = False
BUILD_FA_PERMUTE = BUILD_FLASHMLA = False BUILD_FA_PERMUTE = BUILD_FLASHMLA = False
BUILD_C_INTERFACE = False BUILD_C_INTERFACE = False
BUILD_ASM = False BUILD_ASM = True
FA_DEBUG = True FA_DEBUG = True
FA_DEBUG_SUM_MAX = False FA_DEBUG_SUM_MAX = False
HEADDIM_128_ONLY = False HEADDIM_128_ONLY = False
...@@ -358,11 +386,8 @@ def compute_hg_build_descriptor( ...@@ -358,11 +386,8 @@ def compute_hg_build_descriptor(
GFX_VERSION = "938" GFX_VERSION = "938"
ROCM_PATH = os.environ.get("ROCM_PATH", os.environ.get("ROCM_HOME", "/opt/rocm")) ROCM_PATH = os.environ.get("ROCM_PATH", os.environ.get("ROCM_HOME", "/opt/rocm"))
HG_COMPILER = _resolve_hg_compiler()
if not shutil.which("hipcc"):
raise RuntimeError(
"error: hipcc not found in PATH. Please activate the DTK environment first."
)
if not os.path.isdir(os.path.join(ROCM_PATH, "include")): if not os.path.isdir(os.path.join(ROCM_PATH, "include")):
raise RuntimeError( raise RuntimeError(
f"error: {ROCM_PATH}/include not found. " f"error: {ROCM_PATH}/include not found. "
...@@ -444,6 +469,8 @@ def compute_hg_build_descriptor( ...@@ -444,6 +469,8 @@ def compute_hg_build_descriptor(
DEFINES.append("-DPA_PAGE_BLOCK_SIZE") DEFINES.append("-DPA_PAGE_BLOCK_SIZE")
if MLA_PAGE_BLOCK_SIZE: if MLA_PAGE_BLOCK_SIZE:
DEFINES.append("-DMLA_PAGE_BLOCK_SIZE") DEFINES.append("-DMLA_PAGE_BLOCK_SIZE")
if _is_rocm_5_7():
DEFINES.append("-DROCM_5_7")
OFFLOAD_FLAGS = [f"--offload-arch=gfx{_g}" for _g in GFX_VERSION.split(";") if _g] OFFLOAD_FLAGS = [f"--offload-arch=gfx{_g}" for _g in GFX_VERSION.split(";") if _g]
...@@ -457,23 +484,29 @@ def compute_hg_build_descriptor( ...@@ -457,23 +484,29 @@ def compute_hg_build_descriptor(
INCLUDE_FLAGS += TORCH_INCLUDE_FLAGS INCLUDE_FLAGS += TORCH_INCLUDE_FLAGS
COMMON_FLAGS = [ COMMON_FLAGS = [
"-fPIC",
"-O3", "-O3",
"-std=c++17", "-std=c++17",
"-fPIC",
"-ffast-math", "-ffast-math",
"-fno-finite-math-only", "-fno-finite-math-only",
"-fno-gpu-rdc", "-fno-gpu-rdc",
"-mno-fma",
] ]
DTK_DEVICE_FLAGS = [ DTK_DEVICE_FLAGS = [
"-DHIP_ENABLE_WARP_SYNC_BUILTINS",
"-mllvm", "-mllvm",
"-slp-phi-tree-bb-max-size=10000", "-support-768-vgprs=true",
"-mllvm", "-mllvm",
"-enable-num-vgprs-512=true", "-disable-machine-sink",
"-Rpass-analysis=kernel-resource-usage", "-mcode-object-version=5",
"-ftemplate-backtrace-limit=0",
] ]
if not _is_rocm_5_7():
DTK_DEVICE_FLAGS += [
"-mllvm",
"-amdgpu-enable-rewrite-partial-reg-uses=false",
"-mllvm",
"-allow-gvn-convergent-call=true",
"-mllvm",
"-disallow-uniform-vmed3-combine=true",
]
if os.environ.get("FLASH_ATTN_HG_SAVE_TEMPS", "") == "1": if os.environ.get("FLASH_ATTN_HG_SAVE_TEMPS", "") == "1":
DTK_DEVICE_FLAGS.append("--save-temps") DTK_DEVICE_FLAGS.append("--save-temps")
...@@ -555,6 +588,7 @@ def compute_hg_build_descriptor( ...@@ -555,6 +588,7 @@ def compute_hg_build_descriptor(
"obj_dir": obj_dir, "obj_dir": obj_dir,
"sources": _all_sources, "sources": _all_sources,
"objects": objects, "objects": objects,
"compiler": HG_COMPILER,
"compile_flags": compile_flags, "compile_flags": compile_flags,
"link_flags": link_flags, "link_flags": link_flags,
"out_so": out_so, "out_so": out_so,
...@@ -566,16 +600,17 @@ def run_hg_ninja_build(descriptor: dict) -> None: ...@@ -566,16 +600,17 @@ def run_hg_ninja_build(descriptor: dict) -> None:
"""Write build_hg.ninja and run ninja (parallel via MAX_JOBS).""" """Write build_hg.ninja and run ninja (parallel via MAX_JOBS)."""
build_dir = descriptor["build_dir"] build_dir = descriptor["build_dir"]
ninja_file = descriptor["ninja_path"] ninja_file = descriptor["ninja_path"]
compiler = _ninja_shell_join([descriptor["compiler"]])
out_so_ninja = _ninja_escape_path(descriptor["out_so"]) out_so_ninja = _ninja_escape_path(descriptor["out_so"])
lines = [ lines = [
"ninja_required_version = 1.3", "ninja_required_version = 1.3",
"", "",
"rule hipcc_compile", "rule hg_compile",
" command = hipcc -c $in -o $out $FLAGS", f" command = {compiler} -c $in -o $out $FLAGS",
" description = HG compile $in", " description = HG compile $in",
"", "",
"rule hipcc_link", "rule hg_link",
" command = hipcc -shared -o $out @$out.rsp $LINK_FLAGS", f" command = {compiler} -shared -o $out @$out.rsp $LINK_FLAGS",
" rspfile = $out.rsp", " rspfile = $out.rsp",
" rspfile_content = $in", " rspfile_content = $in",
" description = HG link $out", " description = HG link $out",
...@@ -585,9 +620,9 @@ def run_hg_ninja_build(descriptor: dict) -> None: ...@@ -585,9 +620,9 @@ def run_hg_ninja_build(descriptor: dict) -> None:
"", "",
] ]
for src, obj in zip(descriptor["sources"], descriptor["objects"]): for src, obj in zip(descriptor["sources"], descriptor["objects"]):
lines.append(f"build {_ninja_escape_path(obj)}: hipcc_compile {_ninja_escape_path(src)}") lines.append(f"build {_ninja_escape_path(obj)}: hg_compile {_ninja_escape_path(src)}")
obj_list = " ".join(_ninja_escape_path(obj) for obj in descriptor["objects"]) obj_list = " ".join(_ninja_escape_path(obj) for obj in descriptor["objects"])
lines.append(f"build {out_so_ninja}: hipcc_link {obj_list}") lines.append(f"build {out_so_ninja}: hg_link {obj_list}")
lines.append("") lines.append("")
os.makedirs(build_dir, exist_ok=True) os.makedirs(build_dir, exist_ok=True)
...@@ -878,7 +913,6 @@ if not SKIP_CUDA_BUILD: ...@@ -878,7 +913,6 @@ if not SKIP_CUDA_BUILD:
"csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_sm80.cu", "csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_sm80.cu", "csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_sparse_util.cu"
], ],
extra_compile_args={ extra_compile_args={
"cxx": ["-O3", "-w", "-std=c++17", "cxx": ["-O3", "-w", "-std=c++17",
...@@ -944,6 +978,7 @@ if not SKIP_CUDA_BUILD: ...@@ -944,6 +978,7 @@ if not SKIP_CUDA_BUILD:
Path(this_dir) / "csrc" / "flash_attn", Path(this_dir) / "csrc" / "flash_attn",
Path(this_dir) / "csrc" / "flash_attn" / "src", Path(this_dir) / "csrc" / "flash_attn" / "src",
Path(this_dir) / "csrc" / "cutlass" / "include", Path(this_dir) / "csrc" / "cutlass" / "include",
"/public/home/huangly/数据采集/cutlass_3.2.1/include" ],
) )
) )
...@@ -1051,13 +1086,16 @@ class NinjaBuildExtension(BuildExtension): ...@@ -1051,13 +1086,16 @@ class NinjaBuildExtension(BuildExtension):
if os.path.isdir(HG_SRC_DIR): if os.path.isdir(HG_SRC_DIR):
os.makedirs(HG_BUILD_DIR, exist_ok=True) os.makedirs(HG_BUILD_DIR, exist_ok=True)
_maybe_clean_hg_build_dir(HG_BUILD_DIR) _maybe_clean_hg_build_dir(HG_BUILD_DIR)
print("=== Building HG libflash_attention.so (mode=all, gfx938, ninja) ===")
try: try:
desc = compute_hg_build_descriptor( desc = compute_hg_build_descriptor(
HG_SRC_DIR, HG_SRC_DIR,
HG_BUILD_DIR, HG_BUILD_DIR,
mode="all", mode="all",
extra_options_raw="-DGFX_VERSION=938 -Wl,-Bsymbolic", extra_options_raw="-DGFX_VERSION=938;936 -Wl,-Bsymbolic",
)
print(
"=== Building HG libflash_attention.so "
f"(mode=all, gfx938/936, ninja, compiler={desc['compiler']}) ==="
) )
run_hg_ninja_build(desc) run_hg_ninja_build(desc)
if os.path.isfile(HG_SO_BUILD): if os.path.isfile(HG_SO_BUILD):
......
import os
import sys
import math
import torch
import pickle
import time
import numpy
import argparse
import random
from datetime import datetime
# Detect the installed GPU toolchain by probing for its compiler binary.
use_cuda_toolkits = os.path.exists("/usr/local/cuda/bin/nvcc")
use_rocm_toolkits = os.path.exists("/opt/rocm/llvm/bin/clang")
use_dtk_toolkits = os.path.exists("/opt/dtk/bin/aicc")
# CUDA hosts import the varlen entry point from vLLM. ROCm/DTK hosts try the
# standalone `flash_attention_interface` module first and fall back to the
# `flash_attn` package when that module is not installed.
if (use_cuda_toolkits):
from vllm.vllm_flash_attn import flash_attn_varlen_func
elif (use_rocm_toolkits or use_dtk_toolkits):
try:
from flash_attention_interface import flash_attn_varlen_func, flash_attn_2_cuda, flash_attn_with_kvcache
except ModuleNotFoundError:
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
# NOTE(review): the original indentation of this file was lost, so it is not
# clear whether the import below is module-level or part of the `except`
# branch. `_require_hg_varlen_symbol` reads `flash_attn_cuda` as a global, so
# it must be bound on every path that reaches the HG kernel — confirm.
import flash_attn_2_cuda as flash_attn_cuda
def _require_hg_varlen_symbol(name: str):
symbol = getattr(flash_attn_cuda, name, None)
if symbol is None:
raise RuntimeError(
f"{name} is unavailable in this build. Rebuild flash_attn with HAS_HG_DISPATCH enabled."
)
return symbol
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, do_assert=True, cos_threshold=1e-5) -> None:
    """Print error metrics between ``x`` and ``y`` and optionally assert closeness.

    Both tensors are promoted to float64 before any metric is computed.
    The printed metrics are:

    * ``cos_diff``: ``1 - 2 * sum(x * y) / sum(x * x + y * y)``; ``0`` means
      identical tensors.  Two all-zero tensors are reported as ``0``.
    * ``RMSE``: root mean squared error.
    * ``amax_diff``: largest absolute element-wise difference.
    * ``REL`` / ``rel_max``: mean / max of ``|x - y| / |y|`` (``y`` is the
      reference; its magnitude is clamped to ``1e-12`` to avoid dividing by
      zero).

    Args:
        x: tensor under test.
        y: reference tensor with the same shape as ``x``.
        name: label used in the printed line and in assertion messages.
        do_assert: when true, assert ``cos_diff < cos_threshold``.
        cos_threshold: upper bound for ``cos_diff``.

    Raises:
        AssertionError: on a shape mismatch, or when ``do_assert`` is set and
            ``cos_diff`` is not below ``cos_threshold`` (including NaN).
    """
    assert x.shape == y.shape, "for {}, x and y must have the same shape".format(name)
    x, y = x.double(), y.double()
    delta = x - y
    RMSE = (delta * delta).mean().sqrt().item()
    energy = (x * x + y * y).sum().item()
    if energy == 0.0:
        # Both tensors are entirely zero, hence identical; the clamped
        # denominator would otherwise report the worst-case value of 1.
        cos_diff = 0.0
    else:
        cos_diff = 1 - 2 * (x * y).sum().item() / max(energy, 1e-12)
    amax_diff = delta.abs().max().item()
    # Relative error against the reference, not the raw ratio x / y.
    rel_diff = delta.abs() / y.abs().clamp_min(1e-12)
    rel_diff_mean = rel_diff.mean().item()
    rel_diff_max = rel_diff.max().item()
    print("name:{} cos_diff={:.12f}, RMSE=\x1b[35m{:.12f}\x1b[0m, amax_diff=\x1b[35m{:.12f}\x1b[0m, REL=\x1b[35m{:.12f}\x1b[0m, rel_max=\x1b[35m{:.12f}\x1b[0m".format(
        name, cos_diff, RMSE, amax_diff, rel_diff_mean, rel_diff_max))
    if do_assert:
        assert cos_diff < cos_threshold, "{}: cos_diff {} is not below {}".format(name, cos_diff, cos_threshold)
# Reference ("golden") attention for a single sequence, computed with plain
# torch ops in float32.
#
# The three inputs have their first two dimensions swapped on entry and the
# output is swapped back before returning. Keys/values are repeated
# `h_q // h_kv` times along the leading (post-transpose) dimension so that
# grouped KV heads line up with the query heads.
#
# Masking: when `window_size` is not (-1, -1) a sliding-window mask is applied
# (a negative side means "unbounded", i.e. it is replaced by the key length);
# otherwise, when `is_causal` is set, a causal mask aligned to the bottom-right
# corner (`tril(diagonal=s_k - s_q)`) is applied.
#
# Returns a 4-tuple `(output, lse, scores_max, scores_sum)`; `output` is cast
# back to the input dtype, and all four are moved to the input device.
#
# `original_seqlen_kv`, `split_slice` and `is_bshd` are not referenced in the
# body shown here.
#
# NOTE(review): the indentation of this file was lost, so it cannot be told
# how many of the statements after `if (not return_max_sum):` belong to that
# branch. The only caller in this file leaves `return_max_sum` at its default
# (False) — confirm the intended behaviour for `return_max_sum=True`.
def scaled_dot_product_attention(__query, __key, __value, h_q, h_kv, is_causal=False, USE_CPU=False, return_max_sum=False, original_seqlen_kv=0, split_slice=0, is_bshd=False, window_size=(-1, -1)):
__query = __query.transpose(0, 1).contiguous()
__key = __key.transpose(0, 1).contiguous()
__value = __value.transpose(0, 1).contiguous()
# Optionally compute the golden result on the CPU so that it is not affected by the GPU BLAS
original_device = __query.device
original_dtype = __query.dtype
if (USE_CPU):
__query = __query.cpu()
__key = __key.cpu()
__value = __value.cpu()
# print("scaled_dot_product_attention: ", query.shape, key.shape, value.shape)
__query = __query.float()
__key = __key.float()
__value = __value.float()
# Return results the same way the official implementation does
if (not return_max_sum):
__key = __key.repeat_interleave(h_q // h_kv, dim=0)
__value = __value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = __query @ __key.transpose(-2, -1) / math.sqrt(__query.size(-1))
# MTP > 1, causal/local mask applied
if (window_size != (-1, -1)):
s_q = __query.shape[-2]
s_k = __key.shape[-2]
left, right = window_size
if left < 0:
left = s_k
if right < 0:
right = s_k
row_idx = torch.arange(s_q, dtype=torch.int32, device=attn_weight.device)[:, None]
col_idx = torch.arange(s_k, dtype=torch.int32, device=attn_weight.device)[None, :]
col_idx_limit_left = row_idx + s_k - s_q - left
col_idx_limit_right = row_idx + s_k - s_q + right
temp_mask = (col_idx >= col_idx_limit_left) & (col_idx <= col_idx_limit_right)
attn_weight = attn_weight.masked_fill(temp_mask.logical_not(), float("-inf"))
elif (is_causal):
s_q = __query.shape[-2]
s_k = __key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=__query.dtype, device=attn_weight.device)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=attn_weight.device).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
# NOTE(review): the result of `.to(...)` is discarded, so this line has no
# effect; `attn_bias` was already created with `__query.dtype` above.
attn_bias.to(__query.dtype)
attn_weight += attn_bias
# Extra statistics kept for debugging: per-row max and sum of exp(score - max)
scores_max = attn_weight.to(torch.float32).max(-1)[0]
scores_sum = torch.exp(attn_weight.to(torch.float32) - scores_max.unsqueeze(-1)).sum(dim=-1)
# Reference computation: log-sum-exp, softmax, then the weighted sum of values
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
output = attn_weight @ __value
output = output.transpose(0, 1).contiguous()
return output.to(original_device).to(original_dtype), lse.to(original_device), scores_max.to(original_device), scores_sum.to(original_device)
def set_random_seed(seed=0):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and, when available,
    all CUDA devices), switches cuDNN to deterministic mode with
    benchmarking disabled, and restricts OpenMP and PyTorch to one thread.
    """
    # Python, NumPy and PyTorch generators, in that order.
    for seed_fn in (random.seed, numpy.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        # Seed the generator of every visible GPU.
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    # Single-threaded execution: OpenMP first, then PyTorch's intra-op pool.
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)
if __name__ == '__main__':
    # Command-line driver: builds (or loads) paged-KV decode inputs, computes a
    # golden result with `scaled_dot_product_attention`, runs the HG decode
    # kernel, compares both, and optionally benchmarks / stress-tests the kernel.
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('--load', default=False, action='store_true', help='load path')
    parser.add_argument('--trace', default=False, action='store_true', help='whether dump perf traces')
    parser.add_argument('--bf16', default=False, action='store_true', help='whether use bfloat16 as main dtype')
    parser.add_argument('--fp8', default=False, action='store_true', help='whether use fp8_e4m3 inputs for HG decode')
    parser.add_argument('--pressure', default=False, action='store_true', help='whether do pressure test')
    parser.add_argument('--cpu', default=False, action='store_true', help='whether compute golden via cpu')
    parser.add_argument('--pad', default=False, action='store_true', help='whether make query uncontiguous to simulate vllm behaviors')
    parser.add_argument('--iterations', type=int, default=100, help='pressure test times')
    parser.add_argument('--block_size', type=int, default=128, help='page block_size')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size for generated inputs')
    parser.add_argument('--seq-q', type=int, default=4, help='query length per batch for generated inputs')
    parser.add_argument('--seq-k', type=int, default=2048, help='kv length per batch for generated inputs')
    parser.add_argument('--num-heads', type=int, default=24, help='number of query heads for generated inputs')
    parser.add_argument('--num-heads-kv', type=int, default=2, help='number of kv heads for generated inputs')
    parser.add_argument('--no-causal', dest='causal', default=True, action='store_false', help='disable causal mask for generated inputs')
    parser.add_argument('--window-left', type=int, default=-1, help='left sliding window size')
    parser.add_argument('--window-right', type=int, default=-1, help='right sliding window size')
    # The help text used to be a copy of the --pressure one.
    parser.add_argument('--seed', default=False, action='store_true', help='whether to fix the random seeds for reproducible inputs')
    args = parser.parse_args()

    if (args.seed):
        set_random_seed(212)

    if (args.load):
        # Load the inputs from a file.
        # SECURITY: torch.load unpickles its input; only load trusted dumps.
        nvidia_packet = torch.load("./demo.pt")
        query, key, value, cu_seqlens_q, max_seqlen_q, cache_seqlens, max_seqlen_k, softmax_scale, causal, window_size, alibi_slopes, page_table, softcap, fa_version, q_descale, k_descale, v_descale = nvidia_packet["inputs"]
        vllm_golden = nvidia_packet["outputs"]
        # Derive the required parameters from the loaded tensors.
        batch_size = page_table.shape[0]
        assert batch_size == cu_seqlens_q.shape[0] - 1, "check batch size"
        page_block_size = key.shape[1]
        num_heads_kv = key.shape[2]
        num_heads = query.shape[1]
        head_dim_qk = query.shape[2]
        head_dim_v = key.shape[3]
        infer_dtype = query.dtype
        # Everything below works on q / k_cache / v_cache; without this binding
        # the --load path stopped with a NameError.
        q, k_cache, v_cache = query, key, value
        # The reference implementation compares window_size against a tuple.
        window_size = tuple(window_size)
    else:
        # Generate the inputs randomly.
        batch_size = args.batch_size
        # Query lengths.
        seqlen_q = [args.seq_q for i in range(batch_size)]
        seqlen_q_sum = sum(seqlen_q)
        max_seqlen_q = max(seqlen_q)
        cu_seqlens_q = numpy.array([0] + numpy.cumsum(seqlen_q).tolist()).astype("int32")
        cu_seqlens_q = torch.from_numpy(cu_seqlens_q)
        # KV lengths.
        cache_seqlens = [args.seq_k for i in range(batch_size)]
        # Page block size: 16 on CUDA hosts, otherwise --block_size.
        page_block_size = 16 if (use_cuda_toolkits) else args.block_size
        # Number of pages each sequence actually needs for that block size.
        max_seqlen_k = max(cache_seqlens)
        seqlen_kv_real_required_page = [math.ceil(it / page_block_size) for it in cache_seqlens]
        seqlen_kv_real_required_page_sum = sum(seqlen_kv_real_required_page)
        # The cache itself is sized for the longest KV sequence.
        seqlen_kv_max_required_page = math.ceil(max_seqlen_k / page_block_size)
        seqlen_kv_max_required_page_total = batch_size * seqlen_kv_max_required_page
        # Shuffle the page table.
        shuffle = True
        if (shuffle):
            block_random = torch.randperm(seqlen_kv_max_required_page_total, dtype=torch.int32, device="cuda")
        else:
            block_random = torch.arange(seqlen_kv_max_required_page_total , dtype=torch.int32)
        page_table = []
        seq_block_incre = 0
        for i in range(batch_size):
            blocks_pad = [0] * seqlen_kv_max_required_page
            if (shuffle):
                blocks_pad[:seqlen_kv_real_required_page[i]] = block_random[seq_block_incre: seq_block_incre + seqlen_kv_real_required_page[i]].cpu().tolist()
                seq_block_incre += seqlen_kv_real_required_page[i]
            else:
                blocks_pad = block_random[seq_block_incre: seq_block_incre + seqlen_kv_max_required_page].cpu().tolist()
                seq_block_incre += seqlen_kv_max_required_page
            page_table.append(torch.IntTensor(blocks_pad))
        page_table = torch.stack(page_table).contiguous().to("cuda")
        # Basic parameters.
        head_dim_qk = 128
        head_dim_v = 128
        num_heads = args.num_heads
        num_heads_kv = args.num_heads_kv
        infer_dtype = torch.float16  # float16 unless --bf16 is given
        if (args.bf16): infer_dtype = torch.bfloat16
        softmax_scale = 1.0 / math.sqrt(head_dim_qk)
        causal = args.causal
        window_size = (args.window_left, args.window_right)
        alibi_slopes = None
        softcap = 0.0
        fa_version = 2
        q_descale = torch.ones((batch_size, num_heads), dtype=torch.float32, device="cuda")
        k_descale = torch.ones((batch_size, num_heads_kv), dtype=torch.float32, device="cuda")
        v_descale = torch.ones((batch_size, num_heads_kv), dtype=torch.float32, device="cuda")
        # Input tensors.
        if (args.pad):
            # Slice a wider tensor so that q is non-contiguous.
            query_origin_tensor = torch.randn((seqlen_q_sum, num_heads + 16, head_dim_qk), dtype=infer_dtype, device="cuda")
            q = query_origin_tensor[:, :num_heads]
        else:
            q = torch.randn((seqlen_q_sum, num_heads, head_dim_qk), dtype=infer_dtype, device="cuda")
        k_cache = torch.randn((seqlen_kv_max_required_page_total, page_block_size, num_heads_kv, head_dim_qk), device="cuda", dtype=infer_dtype)
        v_cache = torch.randn((seqlen_kv_max_required_page_total, page_block_size, num_heads_kv, head_dim_v), device="cuda", dtype=infer_dtype)
        vllm_golden = None
        cu_seqlens_q = cu_seqlens_q.to(q.device)
        cache_seqlens = torch.from_numpy(numpy.array(cache_seqlens).astype("int32")).to(q.device)

    # The golden path uses the *_ref tensors. With --fp8 they are the fp8 inputs
    # cast back to infer_dtype, so both paths see the same quantised values.
    q_ref = q
    k_cache_ref = k_cache
    v_cache_ref = v_cache
    if args.fp8:
        if not hasattr(torch, "float8_e4m3fn"):
            raise RuntimeError("This PyTorch build does not support torch.float8_e4m3fn")
        q = q.to(torch.float8_e4m3fn)
        k_cache = k_cache.to(torch.float8_e4m3fn)
        v_cache = v_cache.to(torch.float8_e4m3fn)
        q_ref = q.to(infer_dtype)
        k_cache_ref = k_cache.to(infer_dtype)
        v_cache_ref = v_cache.to(infer_dtype)

    # Show the inputs.
    print("--------------------------------------------------------------------------------------------")
    print("q: ", q.shape, q.dtype, q.is_contiguous(), q.stride())
    print("k_cache: ", k_cache.shape, k_cache.dtype, k_cache.is_contiguous(), k_cache.stride())
    print("v_cache: ", v_cache.shape, v_cache.dtype, v_cache.is_contiguous(), v_cache.stride())
    print("cu_seqlens_q: ", cu_seqlens_q.shape, cu_seqlens_q.dtype, cu_seqlens_q.is_contiguous())
    print("cu_seqlens_q: ", cu_seqlens_q)
    print("max_seqlen_q: ", max_seqlen_q)
    print("cache_seqlens: ", cache_seqlens)
    print("max_seqlen_k: ", max_seqlen_k)
    print("softmax_scale: ", softmax_scale)
    print("causal: ", causal)
    print("window_size: ", window_size)
    print("alibi_slopes: ", alibi_slopes)
    print("page_table: ", page_table.shape, page_table.dtype, page_table.is_contiguous(), page_table.stride())
    print("page_table: ", page_table)
    print("softcap: ", softcap)
    print("fa_version: ", fa_version)
    print("q_descale: ", q_descale.shape, q_descale.dtype, q_descale.tolist())
    print("k_descale: ", k_descale.shape, k_descale.dtype, k_descale.tolist())
    print("v_descale: ", v_descale.shape, v_descale.dtype, v_descale.tolist())
    print("--------------------------------------------------------------------------------------------")

    # Rebuild the contiguous key / value of every sequence from the paged cache.
    key_original = []
    value_original = []
    for b in range(batch_size):
        # Page-table row of this sequence.
        index = page_table[b]
        # Only the pages that are actually in use.
        max_page_blocks = math.ceil(cache_seqlens[b] / page_block_size)
        actual_index = index[:max_page_blocks]
        # Gather those pages, flatten them and cut to the real KV length.
        key_content = k_cache_ref[actual_index]
        key_content = key_content.view(-1, num_heads_kv, head_dim_qk)[:cache_seqlens[b]].contiguous()
        value_content = v_cache_ref[actual_index].view(-1, num_heads_kv, head_dim_v)[:cache_seqlens[b]].contiguous()
        key_original.append(key_content)
        value_original.append(value_content)

    # Split the packed query into per-sequence pieces in the same way.
    query_original = []
    cum_q = 0
    for b in range(batch_size):
        query_len = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
        query_content = q_ref[cum_q: cum_q + query_len]
        query_original.append(query_content.contiguous())
        cum_q += query_len

    # Golden result from the reference self-attention.
    golden = []
    golden_lse = []
    golden_max = []
    for b in range(batch_size):
        tmp_output, lse, scores_max, scores_sum = scaled_dot_product_attention(query_original[b], key_original[b], value_original[b], num_heads, num_heads_kv, is_causal=causal, USE_CPU=args.cpu, window_size=window_size)
        golden.append(tmp_output)
        golden_lse.append(lse)
        golden_max.append(scores_max)
    golden = torch.cat(golden, dim=0)
    golden_lse = torch.cat(golden_lse, dim=-1)
    golden_max = torch.cat(golden_max, dim=-1)
    print("golden: ", golden.shape)
    print("golden_lse: ", golden_lse.shape)
    print("--------------------------------------------------------------------------------------------")

    bshd_pa_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")

    def run_hg_decode(q_in, k_in, v_in):
        """Call the HG paged decode kernel on the given q / k-cache / v-cache."""
        return bshd_pa_decode(
            q_in,
            k_in,
            v_in,
            None,  # out_
            cu_seqlens_q,
            None,  # cu_seqlens_k
            cache_seqlens,
            alibi_slopes,
            page_table,
            max_seqlen_q,
            max_seqlen_k,
            0.0,  # dropout
            softmax_scale,
            False,  # zero_tensors
            causal,
            window_size[0],
            window_size[1],
            softcap,
            True,  # return_softmax_lse,
            1,
            q_descale if args.fp8 else None,
            k_descale if args.fp8 else None,
            v_descale if args.fp8 else None,
            infer_dtype == torch.bfloat16,
        )

    if (True):
        # fa_output, fa_lse = flash_attn_2_cuda.prefix_decode_varlen_fwd(
        fa_output, fa_lse = run_hg_decode(q, k_cache, v_cache)
    else:
        # Alternative path through the public kvcache API; currently disabled.
        fa_output, fa_lse, *rest = flash_attn_with_kvcache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            page_table=page_table,
            cache_seqlens=cache_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k_new=cache_seqlens,
            max_seqlen_q=max_seqlen_q,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            k_descale=k_descale,
            v_descale=v_descale,
            return_softmax_lse=True,
        )
    torch.cuda.synchronize()

    if (vllm_golden is not None):
        # Check the loaded dump against the kernel output.
        cal_diff(fa_output, vllm_golden, "check")
    print("fa_output: ", fa_output.shape)
    if (fa_lse is not None): print("fa_lse: ", fa_lse.shape)

    # Accuracy check against the golden result.
    fp8_threshold = 5e-3
    cal_diff(golden, fa_output, "accuracy", True, fp8_threshold if args.fp8 else 1e-5)
    if (fa_lse is not None): cal_diff(golden_lse, fa_lse, "softmax_lse", True, fp8_threshold if args.fp8 else 1e-5)
    print("--------------------------------------------------------------------------------------------")

    # Performance benchmark.
    def benchmark_prefix_prefill():
        _ = run_hg_decode(q, k_cache, v_cache)

    # Skipped while debugging (FA_DEBUG / HIP_LOG_LEVEL set) or tracing.
    if ((os.getenv("FA_DEBUG") is None) and (os.getenv("HIP_LOG_LEVEL") is None) and not args.trace):
        # Imported here so that triton is only required when benchmarking.
        import triton
        t = triton.testing.do_bench_cudagraph(benchmark_prefix_prefill)
        FLOPS = float(0)
        BYTES = float(0)
        for b in range(batch_size):
            batch_seqlen_q = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
            batch_seqlen_k = cache_seqlens[b]
            # Work skipped by the causal mask.
            undo_flops = batch_seqlen_q * batch_seqlen_q / 2 if (causal) else 0
            qk_flops = num_heads * (batch_seqlen_q * batch_seqlen_k - undo_flops) * head_dim_qk * 2
            pv_flops = num_heads * (batch_seqlen_q * batch_seqlen_k - undo_flops) * head_dim_v * 2
            FLOPS += qk_flops + pv_flops
            q_load = batch_seqlen_q * num_heads * head_dim_qk
            k_load = batch_seqlen_k * num_heads_kv * head_dim_qk  # k load not only once
            v_load = batch_seqlen_k * num_heads_kv * head_dim_v
            BYTES += q_load * q.element_size() + k_load * k_cache.element_size() + v_load * v_cache.element_size()  # ignore storation ?
        print(f"Performance: {t:.3f} ms, \x1b[35m{FLOPS / 10 ** 9 / t:.2f}\x1b[0m TFLOPS, \x1b[35m{BYTES / 10 ** 6 / t:.0f}\x1b[0m GB/s")

    # Pressure test: repeated runs must reproduce the first output bit-for-bit.
    if (args.pressure):
        # At least 100 rounds are always run.
        pressure_count = max(100, args.iterations)
        for p in range(pressure_count):
            pressure_fa_output = torch.zeros_like(fa_output)
            pressure_fa_output, _ = run_hg_decode(q.clone(), k_cache.clone(), v_cache.clone())
            torch.cuda.synchronize()
            is_equal = torch.equal(pressure_fa_output, fa_output)
            if (not is_equal): cal_diff(pressure_fa_output, fa_output, "pressure")
            assert is_equal, "\x1b[31mUnstable\x1b[0m!"
            del pressure_fa_output
            sys.stdout.write("\rPressure Test: {}/{}".format(p + 1, pressure_count))
        print(" \x1b[32mPASS\x1b[0m")
    print("-----------------------------------------------------------------------------------")
import torch import torch
import os
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -6,6 +7,27 @@ from vllm.triton_utils import tl, triton ...@@ -6,6 +7,27 @@ from vllm.triton_utils import tl, triton
import math import math
import time import time
def estimate_unified_attention_bytes(
    batch_size,
    seqlen_q,
    seqlen_k,
    nheads,
    nheads_k,
    d,
    block_size,
    q_bytes,
    k_bytes,
    v_bytes,
    out_bytes=0,
):
    """Estimate the bytes touched by one unified-attention call.

    The estimate is the sum of four parts:

    * query: ``batch_size * seqlen_q * nheads * d`` elements of ``q_bytes``;
    * key + value: ``seqlen_k * batch_size * nheads_k * d`` elements of
      ``k_bytes + v_bytes``;
    * output: the query element count times ``out_bytes`` (0 by default);
    * metadata, 4 bytes per entry: ``batch_size + 1`` entries, ``batch_size``
      entries, and ``batch_size * ceil(seqlen_k / block_size)`` entries.

    Returns:
        The total number of bytes.
    """
    pages_per_seq = math.ceil(seqlen_k / block_size)
    query_elems = batch_size * seqlen_q * nheads * d
    kv_elems = seqlen_k * batch_size * nheads_k * d
    metadata_entries = (batch_size + 1) + batch_size + batch_size * pages_per_seq
    parts = (
        query_elems * q_bytes,
        kv_elems * (k_bytes + v_bytes),
        query_elems * out_bytes,
        metadata_entries * 4,
    )
    return parts[0] + parts[1] + parts[2] + parts[3]
import pytest import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -18,19 +40,19 @@ import pdb ...@@ -18,19 +40,19 @@ import pdb
from einops import rearrange, repeat from einops import rearrange, repeat
from flash_attn import ( from flash_attn import (
flash_attn_func, # flash_attn_func,
flash_attn_kvpacked_func, # flash_attn_kvpacked_func,
flash_attn_qkvpacked_func, # flash_attn_qkvpacked_func,
flash_attn_varlen_func, # flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func, # flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func, # flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache, # flash_attn_with_kvcache,
varlen_fwd_unified, varlen_fwd_unified,
) )
from flash_attn import flash_attn_func # from flash_attn import flash_attn_func
from flash_attn.bert_padding import pad_input, unpad_input # from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size_n # from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb # from flash_attn.layers.rotary import apply_rotary_emb
MAX_HEADDIM_SM8x = 192 MAX_HEADDIM_SM8x = 192
...@@ -1245,60 +1267,102 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False, ...@@ -1245,60 +1267,102 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False,
if is_e5m2: if is_e5m2:
assert cos_diff < 1e-2 assert cos_diff < 1e-2
elif use_fp8: elif use_fp8:
assert cos_diff < 1e-3 assert cos_diff < 5e-3
else: else:
assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5) assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5)
class _NoUnifiedFallback:
def __init__(self, module):
self._module = module
def __getattr__(self, name):
if name == "varlen_fwd_unified":
raise AssertionError("unexpected fallback to flash_attn_cuda.varlen_fwd_unified")
return getattr(self._module, name)
def varlen_fwd_unified_expect_hg(expected_symbol, *args, **kwargs):
    """Run ``varlen_fwd_unified`` and assert it resolved ``expected_symbol``.

    For the duration of the call two globals of ``varlen_fwd_unified``'s own
    module are replaced: ``flash_attn_cuda`` by a ``_NoUnifiedFallback`` proxy
    and ``_require_hg_varlen_symbol`` by a wrapper that records every
    requested name.  Both are restored afterwards, even if the call raises.

    Returns:
        Whatever ``varlen_fwd_unified(*args, **kwargs)`` returned.

    Raises:
        AssertionError: if ``expected_symbol`` was never requested.
    """
    target_globals = varlen_fwd_unified.__globals__
    saved = {
        key: target_globals[key]
        for key in ("flash_attn_cuda", "_require_hg_varlen_symbol")
    }
    requested = []

    def _recording_require(name):
        # Remember the symbol, then defer to the real lookup.
        requested.append(name)
        return saved["_require_hg_varlen_symbol"](name)

    target_globals["flash_attn_cuda"] = _NoUnifiedFallback(saved["flash_attn_cuda"])
    target_globals["_require_hg_varlen_symbol"] = _recording_require
    try:
        outcome = varlen_fwd_unified(*args, **kwargs)
    finally:
        # Put the original objects back no matter how the call ended.
        target_globals.update(saved)

    assert expected_symbol in requested, (
        f"expected {expected_symbol}, got HG calls {requested}"
    )
    return outcome
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Accuracy tests # Accuracy tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("mha_type", ["gqa"]) @pytest.mark.parametrize("mha_type", ["gqa"])
@pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("window_size", [(-1, -1)]) @pytest.mark.parametrize("window_size", [(-1, -1), (511, 0)])
@pytest.mark.parametrize("use_alibi_sqrt", [True, False]) @pytest.mark.parametrize("use_alibi_sqrt", [False])
@pytest.mark.parametrize("use_qq_bias", [True, False]) # seqlen_q > seqlen_k 时 skip @pytest.mark.parametrize("use_qq_bias", [False]) # seqlen_q > seqlen_k 时 skip
@pytest.mark.parametrize("use_sinks", [True, False]) @pytest.mark.parametrize("use_sinks", [False])
@pytest.mark.parametrize("use_mm_prefix", [True, False]) @pytest.mark.parametrize("use_mm_prefix", [False])
@pytest.mark.parametrize("d", [128, 256]) @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"batch_size,seqlen_q,seqlen_k,block_size", "batch_size,seqlen_q,seqlen_k,block_size",
[ [
# --- 场景 1: 标准 Prefill (全量预填充) --- # --- 场景 1: 标准 Prefill (全量预填充) ---
# 验证对角线处理、Is_causal 逻辑以及全量 Bias 覆盖 # 验证对角线处理、Is_causal 逻辑以及全量 Bias 覆盖
(1, 512, 512, 128), # 单 Batch 小尺寸,快速验证 # (1, 512, 512, 128), # 单 Batch 小尺寸,快速验证
(4, 2048, 2048, 128), # 匹配你日志的大尺寸,验证多 Batch 偏移 # (4, 2048, 2048, 128), # 匹配你日志的大尺寸,验证多 Batch 偏移
(2, 1024, 1024, 128), # 较小的 block_size,增加循环迭代次数 # (2, 1024, 1024, 128), # 较小的 block_size,增加循环迭代次数
# --- 场景 2: Decode 场景 (增量推理) --- # --- 场景 2: Decode 场景 (增量推理) ---
# 验证 seqlen_q=1 时,如何正确从 KV Cache 的最后位置读取信息 # 验证 seqlen_q=1 时,如何正确从 KV Cache 的最后位置读取信息
# 此时 qq_bias 实际上只退化为向量加法,最容易测出指针偏移错误 # 此时 qq_bias 实际上只退化为向量加法,最容易测出指针偏移错误
(8, 1, 2048, 128), # 高 Batch 的标准 Decode (32, 4, 2048, 128),
(1, 1, 4096, 128), # 超长上下文 Decode,验证大索引寻址 (16, 4, 2048, 128),
(8, 4, 2048, 128), # 高 Batch 的标准 Decode
# --- 场景 3: Chunked Prefill / Speculative Decoding (分段/投机采样) --- (4, 4, 2048, 128),
# Q 小于 K,但大于 1。这是最难写的逻辑,验证 Is_causal 的动态截断 (2, 4, 2048, 128),
(2, 128, 1024, 128), # Q 是一小段,K 是长历史 (1, 4, 4096, 128), # 超长上下文 Decode,验证大索引寻址
(4, 256, 512, 128), # 验证 Q 和 K 比例较近时的处理
# --- 场景 3: Prefix Prefill ---
# --- 场景 4: 边界非对称尺寸 (非 2 的幂次) --- (1, 16, 128, 128), # fp8 prefill lower boundary
# 专门用来抓那些“假设数据一定是 BlockSize 整数倍”的 Bug (1, 32, 512, 128),
(1, 127, 127, 128), # 刚好差 1 个填满 Block (2, 32, 513, 128), # non block-aligned KV length
(2, 33, 1025, 128), # 非常细碎的 Block 和不规则长度 (2, 64, 2048, 128),
(3, 96, 1537, 128), # non power-of-two batch/length
(2, 128, 1024, 128),
# # --- 场景 4: 边界非对称尺寸 (非 2 的幂次) ---
# # 专门用来抓那些“假设数据一定是 BlockSize 整数倍”的 Bug
# (1, 127, 127, 128), # 刚好差 1 个填满 Block
# (2, 33, 1025, 128), # 非常细碎的 Block 和不规则长度
], ],
) )
def test_unified_attn_2d( def test_unified_attn_2d(
batch_size, seqlen_q, seqlen_k, block_size, batch_size, seqlen_q, seqlen_k, block_size,
d, causal, window_size, softcap, d, causal, window_size, softcap,
mha_type, dtype, mha_type, dtype,
use_alibi_sqrt, use_qq_bias, use_sinks, use_mm_prefix, use_alibi_sqrt, use_qq_bias, use_sinks, use_mm_prefix, use_fp8,
): ):
device = torch.device("cuda") device = torch.device("cuda")
torch.manual_seed(42) torch.manual_seed(42)
nheads = 8 nheads = 24
nheads_k = 1 if mha_type == "gqa" else nheads nheads_k = 2 if mha_type == "gqa" else nheads
softmax_scale = d ** (-0.5) softmax_scale = d ** (-0.5)
MAX_MM_RANGES = 2 MAX_MM_RANGES = 2
...@@ -1310,6 +1374,8 @@ def test_unified_attn_2d( ...@@ -1310,6 +1374,8 @@ def test_unified_attn_2d(
if use_mm_prefix and seqlen_q > seqlen_k: if use_mm_prefix and seqlen_q > seqlen_k:
pytest.skip("mm_prefix not supported when seqlen_q > seqlen_k") pytest.skip("mm_prefix not supported when seqlen_q > seqlen_k")
if causal and window_size != (-1, -1):
pytest.skip("HG local window path is selected with causal=False")
# if use_mm_prefix and not causal: # if use_mm_prefix and not causal:
# pytest.skip("mm_prefix_range is only meaningful with causal=True") # pytest.skip("mm_prefix_range is only meaningful with causal=True")
...@@ -1320,13 +1386,28 @@ def test_unified_attn_2d( ...@@ -1320,13 +1386,28 @@ def test_unified_attn_2d(
k_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype)) k_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype))
v_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype)) v_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype))
q_varlen = torch.cat(q_list, dim=0) if use_fp8:
fp8_dtype = current_platform.fp8_dtype()
q_kernel_list = [q.to(fp8_dtype) for q in q_list]
k_kernel_list = [k.to(fp8_dtype) for k in k_list]
v_kernel_list = [v.to(fp8_dtype) for v in v_list]
q_ref_list = [q.to(dtype) for q in q_kernel_list]
k_ref_list = [k.to(dtype) for k in k_kernel_list]
v_ref_list = [v.to(dtype) for v in v_kernel_list]
else:
q_kernel_list, k_kernel_list, v_kernel_list = q_list, k_list, v_list
q_ref_list, k_ref_list, v_ref_list = q_list, k_list, v_list
q_varlen = torch.cat(q_kernel_list, dim=0)
cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum( cu_seqlens_q[1:] = torch.cumsum(
torch.tensor([seqlen_q] * batch_size, dtype=torch.int32), dim=0 torch.tensor([seqlen_q] * batch_size, dtype=torch.int32), dim=0
) )
seqused_k = torch.tensor([seqlen_k] * batch_size, device=device, dtype=torch.int32) seqused_k = torch.tensor([seqlen_k] * batch_size, device=device, dtype=torch.int32)
k_cache, v_cache, block_table = make_paged_kv(k_list, v_list, block_size, device, dtype) k_cache, v_cache, block_table = make_paged_kv(k_kernel_list, v_kernel_list, block_size, device, q_varlen.dtype)
q_descale = torch.ones((batch_size, nheads), device=device, dtype=torch.float32) if use_fp8 else None
k_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32) if use_fp8 else None
v_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32) if use_fp8 else None
# Build optional tensors # Build optional tensors
alibi_slopes = None alibi_slopes = None
...@@ -1354,7 +1435,7 @@ def test_unified_attn_2d( ...@@ -1354,7 +1435,7 @@ def test_unified_attn_2d(
for i in range(batch_size): for i in range(batch_size):
ref_outs.append( ref_outs.append(
ref_attn( ref_attn(
q_list[i], k_list[i], v_list[i], q_ref_list[i], k_ref_list[i], v_ref_list[i],
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
...@@ -1369,7 +1450,20 @@ def test_unified_attn_2d( ...@@ -1369,7 +1450,20 @@ def test_unified_attn_2d(
ref_out = torch.cat(ref_outs, dim=0) ref_out = torch.cat(ref_outs, dim=0)
# ---- CUDA kernel ---- # ---- CUDA kernel ----
cuda_out, cuda_lse = varlen_fwd_unified( expected_hg_symbol = None
if use_fp8 and seqlen_q >= 16:
expected_hg_symbol = "hg_prefix_prefill_varlen_fwd"
elif use_fp8 or seqlen_q == 1 or 1 < seqlen_q < 16:
expected_hg_symbol = "hg_prefix_decode_varlen_fwd"
elif seqlen_q > 16:
expected_hg_symbol = "hg_prefix_prefill_varlen_fwd"
varlen_runner = (
varlen_fwd_unified
if expected_hg_symbol is None
else lambda *args, **kwargs: varlen_fwd_unified_expect_hg(expected_hg_symbol, *args, **kwargs)
)
cuda_out, cuda_lse = varlen_runner(
q_varlen, k_cache, v_cache, q_varlen, k_cache, v_cache,
cu_seqlens_q, seqused_k, block_table, cu_seqlens_q, seqused_k, block_table,
max_seqlen_q=seqlen_q, max_seqlen_q=seqlen_q,
...@@ -1384,6 +1478,9 @@ def test_unified_attn_2d( ...@@ -1384,6 +1478,9 @@ def test_unified_attn_2d(
s_aux=sinks, s_aux=sinks,
mm_prefix_range=mm_prefix_range, mm_prefix_range=mm_prefix_range,
return_softmax_lse=True, return_softmax_lse=True,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
) )
# # ---- Triton kernel ---- # # ---- Triton kernel ----
...@@ -1416,7 +1513,7 @@ def test_unified_attn_2d( ...@@ -1416,7 +1513,7 @@ def test_unified_attn_2d(
# triton_max_diff = (triton_out - ref_out).abs().max().item() # triton_max_diff = (triton_out - ref_out).abs().max().item()
print( print(
f"\n[{dtype} | causal={causal} | {mha_type} | bs={batch_size} " f"\n[{dtype} | fp8={use_fp8} | causal={causal} | {mha_type} | bs={batch_size} "
f"sq={seqlen_q} sk={seqlen_k} blk={block_size} | " f"sq={seqlen_q} sk={seqlen_k} blk={block_size} | "
f"alibi_sqrt={use_alibi_sqrt} qq_bias={use_qq_bias} " f"alibi_sqrt={use_alibi_sqrt} qq_bias={use_qq_bias} "
f"sinks={use_sinks} mm_prefix={use_mm_prefix}]" f"sinks={use_sinks} mm_prefix={use_mm_prefix}]"
...@@ -1424,7 +1521,7 @@ def test_unified_attn_2d( ...@@ -1424,7 +1521,7 @@ def test_unified_attn_2d(
# f"\n Triton max_diff={triton_max_diff:.4e}" # f"\n Triton max_diff={triton_max_diff:.4e}"
) )
cal_diff(cuda_out, ref_out, "out") cal_diff(cuda_out, ref_out, "out", use_fp8=use_fp8)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -1440,8 +1537,8 @@ def benchmark_unified_attention(): ...@@ -1440,8 +1537,8 @@ def benchmark_unified_attention():
torch.manual_seed(42) torch.manual_seed(42)
dtype = torch.float16 dtype = torch.float16
d = 256 d = 128
block_size = 320 block_size = 128
warmup = 10 warmup = 10
repeat = 50 repeat = 50
...@@ -1452,30 +1549,35 @@ def benchmark_unified_attention(): ...@@ -1452,30 +1549,35 @@ def benchmark_unified_attention():
MAX_MM_RANGES = 2 MAX_MM_RANGES = 2
# GQA # GQA
nheads = 8 nheads = 24
nheads_k = 1 nheads_k = 2
# workload shapes # workload shapes
shapes = [ shapes = [
# (8, 2048, 2048), # (8, 2048, 2048),
# (4, 2048, 2048), # (4, 2048, 2048),
(4, 1, 2048), (1, 4, 51200),
(2, 4, 51200),
(4, 4, 51200),
(8, 4, 51200),
(16, 4, 51200),
(32, 4, 51200),
# (4, 2048, 4096), # (4, 2048, 4096),
] ]
# feature configs (C A Q S P) # feature configs (C A Q S P)
feature_configs = [ feature_configs = [
(1,0,0,0,0), (1,0,0,0,0),
(1,1,0,0,0), # (1,1,0,0,0),
(1,0,1,0,0), # (1,0,1,0,0),
(1,0,0,1,0), # (1,0,0,1,0),
(1,0,0,0,1), # (1,0,0,0,1),
(1,1,1,0,0), # (1,1,1,0,0),
(1,1,0,1,0), # (1,1,0,1,0),
(1,1,0,0,1), # (1,1,0,0,1),
(1,1,1,1,0), # (1,1,1,1,0),
(1,1,1,0,1), # (1,1,1,0,1),
(1,1,1,1,1), # (1,1,1,1,1),
] ]
print("\nUnified Attention GQA Benchmark") print("\nUnified Attention GQA Benchmark")
...@@ -1485,7 +1587,8 @@ def benchmark_unified_attention(): ...@@ -1485,7 +1587,8 @@ def benchmark_unified_attention():
f"{'BS':>3} {'SQ':>6} {'SK':>6} | " f"{'BS':>3} {'SQ':>6} {'SK':>6} | "
f"{'C':>1} {'A':>1} {'Q':>1} {'S':>1} {'P':>1} | " f"{'C':>1} {'A':>1} {'Q':>1} {'S':>1} {'P':>1} | "
f"{'CUDA(ms)':>10} {'Triton(ms)':>11} | " f"{'CUDA(ms)':>10} {'Triton(ms)':>11} | "
f"{'CUDA TFLOPS':>11} {'Triton TFLOPS':>13} | {'Speedup':>8}" f"{'CUDA TFLOPS':>11} {'Triton TFLOPS':>13} | "
f"{'CUDA(GB/s)':>10} {'Triton(GB/s)':>12} | {'Speedup':>8}"
) )
print("-" * 120) print("-" * 120)
...@@ -1548,6 +1651,13 @@ def benchmark_unified_attention(): ...@@ -1548,6 +1651,13 @@ def benchmark_unified_attention():
triton_out = torch.zeros_like(q_varlen) triton_out = torch.zeros_like(q_varlen)
total_bytes = estimate_unified_attention_bytes(
batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, block_size,
q_bytes=torch.finfo(dtype).bits // 8,
k_bytes=torch.finfo(dtype).bits // 8,
v_bytes=torch.finfo(dtype).bits // 8,
)
for C, A, Q, S, P in feature_configs: for C, A, Q, S, P in feature_configs:
causal = bool(C) causal = bool(C)
...@@ -1655,6 +1765,9 @@ def benchmark_unified_attention(): ...@@ -1655,6 +1765,9 @@ def benchmark_unified_attention():
# FLOPs # FLOPs
flops = 4.0 * batch_size * nheads * seqlen_q * seqlen_k * d flops = 4.0 * batch_size * nheads * seqlen_q * seqlen_k * d
cuda_bandwidth = total_bytes / 1e9 / cuda_ms * 1000
triton_bandwidth = total_bytes / 1e9 / triton_ms * 1000
cuda_tflops = flops / cuda_ms / 1e9 cuda_tflops = flops / cuda_ms / 1e9
triton_tflops = flops / triton_ms / 1e9 triton_tflops = flops / triton_ms / 1e9
...@@ -1663,11 +1776,130 @@ def benchmark_unified_attention(): ...@@ -1663,11 +1776,130 @@ def benchmark_unified_attention():
f"{C} {A} {Q} {S} {P} | " f"{C} {A} {Q} {S} {P} | "
f"{cuda_ms:10.3f} {triton_ms:11.3f} | " f"{cuda_ms:10.3f} {triton_ms:11.3f} | "
f"{cuda_tflops:11.2f} {triton_tflops:13.2f} | " f"{cuda_tflops:11.2f} {triton_tflops:13.2f} | "
f"{cuda_bandwidth:10.2f} {triton_bandwidth:12.2f} | "
f"{triton_ms/cuda_ms:8.2f}x" f"{triton_ms/cuda_ms:8.2f}x"
) )
print("=" * 120) print("=" * 120)
def benchmark_hg_b16_fp8_pa():
    """Time the unified paged-attention forward pass with BF16 and FP8 inputs.

    For every (window, batch, seqlen_q, seqlen_k) combination the same random
    Q/K/V data is pushed through ``varlen_fwd_unified`` twice: once as BF16
    and once cast to the platform FP8 dtype with all-ones descale tensors.
    One table row per combination reports the mean latency of each variant,
    the BF16/FP8 latency ratio, and a bandwidth figure derived from
    ``estimate_unified_attention_bytes``.

    Latency is host wall-clock time over ``repeat`` back-to-back calls,
    bracketed by ``torch.cuda.synchronize()``, after ``warmup`` untimed calls.
    """
    device = torch.device("cuda")
    torch.manual_seed(42)

    dtype = torch.bfloat16
    fp8_dtype = current_platform.fp8_dtype()
    d = 128
    block_size = 128
    warmup = 10
    repeat = 50
    nheads = 24
    nheads_k = 2
    softcap = 0.0
    softmax_scale = d ** (-0.5)
    no_window = (-1, -1)

    # (batch, seqlen_q, seqlen_k) triples: short queries against a 51200-token
    # KV first, then longer queries against a 4096-token KV.
    shapes = []
    for kv_len, q_lens, batch_sizes in (
        (51200, (1, 4), (1, 2, 4, 8, 16, 32)),
        (4096, (32, 128), (1, 2, 4)),
    ):
        for q_len in q_lens:
            for bs in batch_sizes:
                shapes.append((bs, q_len, kv_len))

    windows = [no_window, (511, 0)]

    def _mean_latency_ms(launch):
        """Return the mean wall-clock milliseconds per ``launch()`` call."""
        for _ in range(warmup):
            launch()
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(repeat):
            launch()
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
        return elapsed / repeat * 1000

    def _random_sequences(seq_len, heads, count):
        """Draw ``count`` random (seq_len, heads, d) tensors in ``dtype``."""
        return [
            torch.randn(seq_len, heads, d, device=device, dtype=dtype)
            for _ in range(count)
        ]

    def _unit_descale(rows, heads):
        """Build an all-ones float32 descale tensor of shape (rows, heads)."""
        return torch.ones((rows, heads), device=device, dtype=torch.float32)

    rule = "=" * 104
    print("\nHG Unified PA BF16/FP8 Benchmark")
    print(rule)
    print(
        f"{'BS':>3} {'SQ':>3} {'SK':>6} {'WINDOW':>12} | "
        f"{'BF16(ms)':>9} {'FP8(ms)':>9} {'Speedup':>8} | "
        f"{'BF16 GB/s':>10} {'FP8 GB/s':>9}"
    )
    print("-" * 104)

    for window_size in windows:
        for batch_size, seqlen_q, seqlen_k in shapes:
            # The windowed configuration is only measured for seqlen_q < 16.
            if window_size != no_window and seqlen_q >= 16:
                continue

            # Draw all Q, then all K, then all V, so the random stream is
            # consumed in a fixed order.
            q_list = _random_sequences(seqlen_q, nheads, batch_size)
            k_list = _random_sequences(seqlen_k, nheads_k, batch_size)
            v_list = _random_sequences(seqlen_k, nheads_k, batch_size)

            q_b16 = torch.cat(q_list, dim=0)
            k_b16, v_b16, block_table = make_paged_kv(k_list, v_list, block_size, device, dtype)

            # FP8 operands are straight casts of the BF16 tensors; the paged
            # layout and block table are shared between both variants.
            q_fp8 = q_b16.to(fp8_dtype)
            k_fp8 = k_b16.to(fp8_dtype)
            v_fp8 = v_b16.to(fp8_dtype)

            per_seq_q = torch.tensor([seqlen_q] * batch_size, dtype=torch.int32)
            cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
            cu_seqlens_q[1:] = torch.cumsum(per_seq_q, dim=0)
            seqused_k = torch.tensor([seqlen_k] * batch_size, device=device, dtype=torch.int32)

            q_descale = _unit_descale(batch_size, nheads)
            k_descale = _unit_descale(batch_size, nheads_k)
            v_descale = _unit_descale(batch_size, nheads_k)

            # NOTE(review): test_unified_attn_2d skips causal=True combined
            # with a sliding window ("HG local window path is selected with
            # causal=False"); confirm causal=True is intended for (511, 0).
            attn_kwargs = dict(
                max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k,
                softmax_scale=softmax_scale, causal=True,
                window_size=window_size, softcap=softcap,
                return_softmax_lse=False,
            )

            def run_b16():
                varlen_fwd_unified(
                    q_b16, k_b16, v_b16, cu_seqlens_q, seqused_k, block_table,
                    **attn_kwargs,
                )

            def run_fp8():
                varlen_fwd_unified(
                    q_fp8, k_fp8, v_fp8, cu_seqlens_q, seqused_k, block_table,
                    q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
                    **attn_kwargs,
                )

            # One untimed call of each variant before the measured runs.
            for launch in (run_b16, run_fp8):
                launch()
            torch.cuda.synchronize()

            b16_ms = _mean_latency_ms(run_b16)
            fp8_ms = _mean_latency_ms(run_fp8)

            def _traffic_bytes(elem_bytes):
                """Estimated bytes moved when Q/K/V use ``elem_bytes`` each."""
                return estimate_unified_attention_bytes(
                    batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, block_size,
                    q_bytes=elem_bytes, k_bytes=elem_bytes, v_bytes=elem_bytes,
                    out_bytes=2,
                )

            b16_bytes = _traffic_bytes(2)
            fp8_bytes = _traffic_bytes(1)

            print(
                f"{batch_size:3d} {seqlen_q:3d} {seqlen_k:6d} {str(window_size):>12} | "
                f"{b16_ms:9.3f} {fp8_ms:9.3f} {b16_ms / fp8_ms:8.2f} | "
                f"{b16_bytes / 1e9 / b16_ms * 1000:10.2f} {fp8_bytes / 1e9 / fp8_ms * 1000:9.2f}"
            )

    print(rule)
if __name__ == "__main__":
    # Set RUN_UNIFIED_BENCHMARK=1 to run the BF16-vs-FP8 benchmark; otherwise
    # a single FP8 case of test_unified_attn_2d is executed directly.
    if os.getenv("RUN_UNIFIED_BENCHMARK") == "1":
        benchmark_hg_b16_fp8_pa()
    else:
        test_unified_attn_2d(
            batch_size=1, seqlen_q=1, seqlen_k=40960, block_size=128,
            d=128, causal=False, window_size=(511, 0), softcap=0.0,
            mha_type="gqa", dtype=torch.bfloat16,
            use_alibi_sqrt=False, use_qq_bias=False, use_sinks=False, use_mm_prefix=False,
            use_fp8=True,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment