Commit 5248d7d2 authored by hly's avatar hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
#ifdef BUILD_FA_PERMUTE
#include <hip/hip_runtime.h>
#include "../../include/intrinsic.h"
#include "../flash_fwd_permute_hdim128.h"
template<>
......@@ -141,27 +140,21 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<128, 4, 32>(
int32_t thread_offset = lane_id_col * 8;
// 一次读取 4x128 的 Half 到 LDS
#if defined(__gfx936__) || defined(__gfx938__)
{
auto *lds_ptr = (__attribute__((address_space(3))) int *)(
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float));
__builtin_hcu_raw_buffer_load_lds(
read_buffer,
lds_ptr,
lds + lane_id * 4,
16,
(block_offset + thread_offset) << 1, /* v_offset */
0, /* s_offset */
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif
// 从 LDS 转置后, 64 个线程写 4 行, 每次写 128 个 Half, 对应 fetch * 4 + [0,3] 的 seqlen
vec2_fp32 data0, data1;
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 0) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[0];
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 1) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[1];
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 2) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data1[0];
......
#ifdef BUILD_FA_PERMUTE
#include <hip/hip_runtime.h>
#include "../../include/intrinsic.h"
#include "../flash_fwd_permute_hdim128.h"
......@@ -119,28 +118,22 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 0>(
int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__)
{
auto *lds_ptr = (__attribute__((address_space(3))) int *)(
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float));
__builtin_hcu_raw_buffer_load_lds(
read_buffer,
lds_ptr,
lds + lane_id * 4,
16,
(block_offset + thread_offset) << 1, /* v_offset */
0, /* s_offset */
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif
// 写到 lds 不需要同步, 因为只有一个 wave
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
vec2_fp32 data0, data1;
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
write_ptr[(fetch * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0];
write_ptr[(fetch * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1];
write_ptr[(fetch * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0];
......@@ -268,28 +261,22 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 32>(
int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__)
{
auto *lds_ptr = (__attribute__((address_space(3))) int *)(
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float));
__builtin_hcu_raw_buffer_load_lds(
read_buffer,
lds_ptr,
lds + lane_id * 4,
16,
(block_offset + thread_offset) << 1, /* v_offset */
0, /* s_offset */
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif
// 写到 lds 不需要同步, 因为只有一个 wave
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
vec2_fp32 data0, data1;
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0];
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1];
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0];
......@@ -383,9 +370,8 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
__builtin_amdgcn_sched_barrier(0);
#endif
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
inlineasm_fa_ds_read2_b32(lds, fetch * 256 + lane_id, registers_buffer[fetch * 2], 0, 64);
inlineasm_fa_ds_read2_b32(lds, fetch * 256 + lane_id + 128, registers_buffer[fetch * 2 + 1], 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
registers_buffer[fetch * 2] = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + fetch * 256 + lane_id, 0, 64, false);
registers_buffer[fetch * 2 + 1] = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + fetch * 256 + lane_id + 128, 0, 64, false);
}
__builtin_amdgcn_sched_barrier(0);
......
......@@ -26,7 +26,6 @@ if torch.cuda.is_available():
flash_attn_varlen_with_mask_func,
# unified attn functions
varlen_fwd_unified,
fwd_sparse_mean_pool_fast,
)
# triton fa interface
from flash_attn.flash_attn_triton_interface import flash_attn_func as triton_flash_attn_func
......
......@@ -2,7 +2,6 @@
from typing import Optional, Union
from typing import List, Tuple
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -19,6 +18,12 @@ from flash_attn.utils.sparse_utils import hyperparameter_check, get_block_map_me
DEFAULT_FA_VERSION = 2
try:
torch._C._dispatch_find_schema_or_throw("flash_attn2_c_op::varlen_fwd", "")
_has_flash_attn2_c_varlen_fwd = True
except RuntimeError:
_has_flash_attn2_c_varlen_fwd = False
def maybe_contiguous(x):
    """Return ``x`` in a form whose innermost stride is 1.

    ``None`` is passed through untouched. A value whose ``stride(-1)`` is
    already 1 is returned as-is (same object); otherwise the result of
    ``x.contiguous()`` is returned.
    """
    if x is None:
        return x
    if x.stride(-1) == 1:
        return x
    return x.contiguous()
......@@ -59,7 +64,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
elif head_dim <= 512:
return 64
if torch.__version__ >= "2.4.0":
if torch.__version__ >= "2.4.0" and _has_flash_attn2_c_varlen_fwd:
_torch_custom_op_wrapper = torch.library.custom_op
_torch_register_fake_wrapper = torch.library.register_fake
else:
......@@ -199,7 +204,11 @@ def varlen_fwd_fake(
return out, softmax_lse, p, rng_state
_wrapped_flash_attn_varlen_forward = torch.ops.flash_attn2_c_op.varlen_fwd
_wrapped_flash_attn_varlen_forward = (
torch.ops.flash_attn2_c_op.varlen_fwd
if _has_flash_attn2_c_varlen_fwd
else _flash_attn_varlen_forward
)
def _flash_attn_backward(
......@@ -2008,7 +2017,7 @@ def vllm_flash_attn_varlen_func(
# if mtp, k head must be 1.
# todo : support k head >1
is_mtp = (max_seqlen_q*bs==total_q and max_seqlen_q>1 and max_seqlen_q<5)
if (max_seqlen_q==1 or is_mtp ) and real_window_size[0]==-1:
if (max_seqlen_q==1 or (is_mtp and k.shape[1]==1)) and real_window_size[0]==-1:
if out==None:
if q.dtype == torch.float8_e4m3fn or q.dtype == torch.float8_e5m2:
out = torch.empty(q.size(),device = q.device,dtype=torch.bfloat16)
......@@ -2732,6 +2741,9 @@ def varlen_fwd_unified(
*,
out=None,
return_softmax_lse=False,
q_descale=None,
k_descale=None,
v_descale=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -2741,6 +2753,78 @@ def varlen_fwd_unified(
window_size_left, window_size_right = window_size
if k.dtype.is_floating_point:
k_dtype_bits = torch.finfo(k.dtype).bits
else:
k_dtype_bits = torch.iinfo(k.dtype).bits
is_mtp = (max_seqlen_q * seqused_k.size(0) == q.shape[0] and 1 < max_seqlen_q < 16)
if max_seqlen_q >= 16:
bshd_prefill = _require_hg_varlen_symbol("hg_prefix_prefill_varlen_fwd")
fa_output, *rest_extend = bshd_prefill(
q,
k,
v,
out, # out_
cu_seqlens_q,
None, # cu_seqlens_k
seqused_k,
alibi_slopes,
block_table,
max_seqlen_q,
max_seqlen_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
causal,
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
1,
None if (k_dtype_bits == 16) else q_descale,
None if (k_dtype_bits == 16) else k_descale,
None if (k_dtype_bits == 16) else v_descale,
True,
)
return (fa_output, rest_extend[0]) if return_softmax_lse else fa_output
bs = seqused_k.size(0)
total_q = q.shape[0]
if max_seqlen_q == 1 or is_mtp:
assert not use_alibi_sqrt and qq_bias is None and s_aux is None and mm_prefix_range is None, \
f"Arguments not supported in hg_bshd_decode"
bshd_pa_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")
result = bshd_pa_decode(
q,
k,
v,
out,
cu_seqlens_q,
None,
seqused_k,
alibi_slopes,
block_table,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale,
False,
causal,
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
1,
None if (k_dtype_bits == 16) else q_descale,
None if (k_dtype_bits == 16) else k_descale,
None if (k_dtype_bits == 16) else v_descale,
True,
)
fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output
out, softmax_lse = flash_attn_cuda.varlen_fwd_unified(
q,
k,
......@@ -3816,22 +3900,6 @@ def spas_fa2_attn_meansim_topk_varlen_cuda(
)
def fwd_sparse_mean_pool_fast(x,BLK,mean=None):
    """Forward ``x``, ``BLK`` and ``mean`` to ``flash_attn_cuda.fwd_sparse_mean_pool_fast``.

    The extension's return value is handed back unchanged. ``mean`` is
    optional and defaults to ``None``.
    """
    pooled = flash_attn_cuda.fwd_sparse_mean_pool_fast(x, BLK, mean)
    return pooled
def get_block_map_fast(q, k, topk_ratio, BLKQ=128, BLKK=64):
    """Select the top-scoring key blocks for every query block.

    ``k`` is averaged over its third-from-last dimension (kept as a size-1
    axis) and that mean is passed, together with ``k`` and ``BLKK``, to
    ``fwd_sparse_mean_pool_fast``; ``q`` is pooled with ``BLKQ`` and no mean.
    The pooled tensors are multiplied (pooled q times pooled k with its last
    two axes swapped) to give a block-score tensor.

    Returns:
        A 3-tuple ``(sparse_map, lut, topk)``:
        ``sparse_map`` is an int8 tensor shaped like the score tensor holding
        1 at every selected position and 0 elsewhere; ``lut`` holds the
        (unsorted) selected indices along the last axis; ``topk`` is
        ``min(K, int(topk_ratio * K))`` where ``K`` is the size of the score
        tensor's last axis.
    """
    k_mean = torch.mean(k, dim=-3, keepdim=True)
    k_block_pool = fwd_sparse_mean_pool_fast(k, BLKK, k_mean)
    q_block_pool = fwd_sparse_mean_pool_fast(q, BLKQ)

    block_scores = q_block_pool @ k_block_pool.transpose(-1, -2)

    num_candidates = block_scores.shape[-1]
    topk = min(num_candidates, int(topk_ratio * num_candidates))

    # sorted=False: only membership matters for the scatter below.
    lut = torch.topk(block_scores, topk, dim=-1, sorted=False).indices

    sparse_map = torch.zeros_like(block_scores, dtype=torch.int8)
    sparse_map.scatter_(-1, lut, 1)
    return sparse_map, lut, topk
class SparseLinearAttention(nn.Module):
def __init__(self, head_dim, topk, feature_map='softmax', use_bf16=True, use_fp8=False, tie_feature_map_qk=True):
R'''
......@@ -3888,15 +3956,19 @@ class SparseLinearAttention(nn.Module):
'''
B, seqlen_q, H, headdim = q.shape
q_bhld = q.transpose(1, 2).contiguous() # (B, H, L, D)
k_bhld = k.transpose(1, 2).contiguous()
# v_bhld = v.transpose(1, 2).contiguous()
# import pdb
# pdb.set_trace()
if headdim == 64:
block_m = 64 if seqlen_q <= 2048 else 128
elif headdim == 128:
block_m = 64 if seqlen_q <= 2048 else 128
block_k = 64
if headdim == 64:
sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
else:
sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
sparse_map, lut, real_topk = get_block_map(q_bhld, k_bhld, topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
q = q.to(self.dtype)
k = k.to(self.dtype)
......@@ -3912,10 +3984,10 @@ class SparseLinearAttention(nn.Module):
seqlen_k = k.size(1)
num_blocks_q = (seqlen_q + block_m - 1) // block_m
num_blocks_k = (seqlen_k + block_k - 1) // block_k
column_count = torch.empty(
column_count = torch.zeros(
(B, H, num_blocks_q), dtype=torch.int32, device=q.device
)
column_index = torch.empty(
column_index = torch.zeros(
(B, H, num_blocks_q, 1), dtype=torch.int32, device=q.device
)
......@@ -3984,56 +4056,14 @@ def sparse_attn_with_sla(
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
dtype = torch.bfloat16 if use_bf16 else torch.float16
dtype = torch.float8_e4m3fn if use_fp8 else dtype
B, seqlen_q, H, headdim = q.shape
assert not (use_bf16 and use_fp8), "Only one of bf16 and fp8 can be used."
assert headdim in (64, 128), "Dtype fp16/bf16 only support dim (64, 128)."
assert not (use_fp8 and headdim==64), "Dtype fp8 only support dim 128."
if headdim == 64:
block_m = 64 if seqlen_q <= 2048 else 128
elif headdim == 128:
block_m = 64 if seqlen_q <= 2048 else 128
block_k = 64
if headdim == 64:
sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=topk, BLKQ=block_m, BLKK=block_k)
else:
sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=topk, BLKQ=block_m, BLKK=block_k)
q = q.to(dtype)
k = k.to(dtype)
v = v.to(dtype)
########## SPARGE BEGIN ##########
headdim = q.size(-1)
block_offset, block_count = block_map_to_block_offset_triton(sparse_map)
block_offset = block_offset * block_k
softmax_scale = 1.0 / (headdim ** 0.5)
assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
seqlen_k = k.size(1)
num_blocks_q = (seqlen_q + block_m - 1) // block_m
num_blocks_k = (seqlen_k + block_k - 1) // block_k
column_count = torch.empty(
(B, H, num_blocks_q), dtype=torch.int32, device=q.device
)
column_index = torch.empty(
(B, H, num_blocks_q, 1), dtype=torch.int32, device=q.device
)
o_s = sparse_attn_func(
q, k, v, # Use original BLHD layout
block_count=block_count,
block_offset=block_offset,
column_count=column_count,
column_index=column_index,
softmax_scale=softmax_scale,
is_sla=True,
)
if return_sparsity:
return o_s, real_topk / sparse_map.shape[-1]
else:
return o_s
attn = SparseLinearAttention(
head_dim=q.size(-1),
topk=topk, # = 1 - sparsity
feature_map=feature_map, # options: elu, relu, softmax
use_bf16=use_bf16,
use_fp8=use_fp8,
).cuda()
return attn(q, k, v, return_sparsity=return_sparsity)
def _require_hg_varlen_symbol(name: str):
......@@ -4045,15 +4075,6 @@ def _require_hg_varlen_symbol(name: str):
return symbol
def _apply_hg_kvcache_safe_env() -> None:
# DTK/gfx938 PA launch is sensitive to these knobs. Keep the old HG-safe
# defaults unless the caller explicitly asks for the raw kernel selection.
if os.environ.get("HG_KVCACHE_RAW_KERNEL") == "1":
return
os.environ.setdefault("PA_NO_MLS", "1")
os.environ.setdefault("PA_USE_TILE32X32", "1")
def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None:
if k_cache.dim() != 4 or v_cache.dim() != 4:
raise ValueError("HG paged KV cache expects k and v to both be 4D tensors")
......@@ -4067,53 +4088,6 @@ def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None:
)
def _normalize_hg_paged_q_scales(q_scale, batch_size, num_heads_q, num_heads_k):
    """Bring ``q_scale`` to a contiguous ``[batch_size, num_heads_q]`` layout.

    Accepted inputs:
      * a 1-D tensor with ``batch_size * num_heads_q`` or
        ``batch_size * num_heads_k`` elements (viewed as 2-D first);
      * a 2-D tensor whose first dimension is ``batch_size`` and whose second
        dimension is ``num_heads_q`` or ``num_heads_k``.

    When the second dimension is ``num_heads_k`` and ``num_heads_q`` is a
    multiple of it, every column is repeated ``num_heads_q // num_heads_k``
    times along dim 1.

    Raises:
        ValueError: if ``q_scale`` is ``None`` or has any other shape.
    """
    shape_error = (
        "q_descale must have shape [batch_size, num_heads_q] "
        "or [batch_size, num_heads_k] for HG int8 paged-kvcache path"
    )

    if q_scale is None:
        raise ValueError("q_descale must be provided for HG int8 paged-kvcache path")
    q_scale = maybe_contiguous(q_scale)

    # Flat inputs are reshaped only when their length matches one of the two
    # supported layouts; anything else falls through to the shape check.
    if q_scale.dim() == 1:
        flat_len = q_scale.numel()
        if flat_len == batch_size * num_heads_q:
            q_scale = q_scale.view(batch_size, num_heads_q)
        elif flat_len == batch_size * num_heads_k:
            q_scale = q_scale.view(batch_size, num_heads_k)

    if not (q_scale.dim() == 2 and q_scale.shape[0] == batch_size):
        raise ValueError(shape_error)

    head_count = q_scale.shape[1]
    if head_count == num_heads_q:
        return q_scale.contiguous()
    if head_count == num_heads_k and num_heads_q % num_heads_k == 0:
        repeats = num_heads_q // num_heads_k
        return q_scale.repeat_interleave(repeats, dim=1).contiguous()
    raise ValueError(shape_error)
def _expand_hg_paged_kv_scales(scale, block_table, page_block_size, num_heads_k, name):
    """Expand per-batch KV scales to one entry per (block, slot, head).

    ``scale`` must be ``[batch_size, num_heads_k]`` (or a flat tensor with
    that many elements), where ``batch_size`` is ``block_table.shape[0]``.
    The result has shape
    ``(block_table.max() + 1, page_block_size, num_heads_k)``; for every batch
    row, the blocks listed in that row of ``block_table`` are filled with the
    row's scale vector.

    The output buffer is created with ``torch.empty``, so blocks that appear
    in no row of ``block_table`` are left uninitialised.

    ``name`` is only used to label error messages.

    Raises:
        ValueError: if ``scale`` is ``None`` or does not have the expected shape.
    """
    if scale is None:
        raise ValueError(f"{name} must be provided for HG int8 paged-kvcache path")
    scale = maybe_contiguous(scale)

    batch_size = block_table.shape[0]
    if scale.dim() == 1 and scale.numel() == batch_size * num_heads_k:
        scale = scale.view(batch_size, num_heads_k)
    if not (scale.dim() == 2 and scale.shape == (batch_size, num_heads_k)):
        raise ValueError(
            f"{name} must have shape [batch_size, num_heads_k] for HG int8 paged-kvcache path"
        )

    # .item() forces a host sync; the largest block id fixes the leading dim.
    num_blocks = int(block_table.max().item()) + 1
    expanded = torch.empty(
        (num_blocks, page_block_size, num_heads_k),
        device=scale.device,
        dtype=scale.dtype,
    )

    for row in range(batch_size):
        row_blocks = block_table[row].to(dtype=torch.long)
        row_scale = scale[row].view(1, 1, num_heads_k)
        expanded[row_blocks] = row_scale.expand(
            row_blocks.numel(), page_block_size, num_heads_k
        )
    return expanded.contiguous()
def hg_flash_attn_varlen_func(
q,
k,
......@@ -4333,10 +4307,6 @@ def hg_flash_attn_varlen_func(
raise ValueError("cu_seqlens_k and seqused_k cannot be provided at the same time")
if block_table is None:
raise ValueError("block_table must be provided when seqused_k is used")
if return_attn_probs:
raise NotImplementedError(
"return_attn_probs is not supported for HG prefix/paged compatibility paths"
)
if dropout_p != 0.0:
raise NotImplementedError("dropout_p must be 0.0 for HG prefix/paged compatibility paths")
......@@ -4366,7 +4336,7 @@ def hg_flash_attn_varlen_func(
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
return_softmax_lse or return_attn_probs,
1,
None if k_dtype_bits == 16 else q_descale,
None if k_dtype_bits == 16 else k_descale,
......@@ -4374,7 +4344,7 @@ def hg_flash_attn_varlen_func(
is_bf16_output,
)
fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output
return (fa_output, result[1]) if (return_softmax_lse or return_attn_probs) else fa_output
if k_dtype_bits == 16:
prefix_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")
......@@ -4397,13 +4367,17 @@ def hg_flash_attn_varlen_func(
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
return_softmax_lse or return_attn_probs,
1,
None,
None,
None,
is_bf16_output,
)
fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output
return (fa_output, result[1]) if (return_softmax_lse or return_attn_probs) else fa_output
if return_softmax_lse:
if return_softmax_lse or return_attn_probs:
raise NotImplementedError(
"return_softmax_lse is not supported for the HG paged-kvcache compatibility path"
)
......@@ -4411,28 +4385,6 @@ def hg_flash_attn_varlen_func(
_validate_hg_paged_kv_contract(k, v)
if k.shape[1] != 128:
raise NotImplementedError("HG paged-kvcache path currently requires page_block_size == 128")
_apply_hg_kvcache_safe_env()
q_descale = _normalize_hg_paged_q_scales(
q_descale,
batch_size=block_table.shape[0],
num_heads_q=q.shape[1],
num_heads_k=k.shape[2],
)
k_descale = _expand_hg_paged_kv_scales(
k_descale,
block_table=block_table,
page_block_size=k.shape[1],
num_heads_k=k.shape[2],
name="k_descale",
)
v_descale = _expand_hg_paged_kv_scales(
v_descale,
block_table=block_table,
page_block_size=v.shape[1],
num_heads_k=v.shape[2],
name="v_descale",
)
hg_kvcache = _require_hg_varlen_symbol("hg_fwd_kvcache_bshd")
result = hg_kvcache(
q.unsqueeze(1),
......@@ -4442,7 +4394,7 @@ def hg_flash_attn_varlen_func(
None,
None,
seqused_k,
max_seqlen_k if max_seqlen_k > 0 else int(seqused_k.max().item()),
1,
None,
None,
None,
......@@ -4456,12 +4408,12 @@ def hg_flash_attn_varlen_func(
window_size[1],
softcap,
False,
num_splits,
-1,
None,
None,
q_descale,
k_descale,
v_descale,
None if k_dtype_bits == 16 else q_descale,
None if k_dtype_bits == 16 else k_descale,
None if k_dtype_bits == 16 else v_descale,
is_bf16_output,
)
return result[0].squeeze(1)
......@@ -112,6 +112,8 @@ _HG_EXPLICIT_SOURCES_BY_MODE = {
"src/target/flash_fwd_hdim128_padding_mask_fp16.cpp",
"src/target/flash_fwd_hdim128_prefix_prefill_bf16.cpp",
"src/target/flash_fwd_hdim128_prefix_prefill_fp16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp",
"src/target/flash_fwd_hdim160_bf16.cpp",
"src/target/flash_fwd_hdim160_fp16.cpp",
"src/target/flash_fwd_hdim192_bf16.cpp",
......@@ -262,6 +264,32 @@ def _ninja_shell_join(args) -> str:
return " ".join(_ninja_escape(shlex.quote(str(x))) for x in args)
def _resolve_hg_compiler() -> str:
candidates = [
os.environ.get("FLASH_ATTN_HG_COMPILER"),
"/opt/dtk/bin/aicc",
"aicc",
]
for compiler in candidates:
if compiler and shutil.which(compiler):
return compiler
requested = [c for c in candidates if c]
raise RuntimeError(
"error: no usable HG aicc compiler found from: "
+ ", ".join(repr(c) for c in requested)
+ ". Set FLASH_ATTN_HG_COMPILER to the DTK aicc path."
)
def _is_rocm_5_7() -> bool:
version_file = "/opt/rocm/.info/version"
try:
with open(version_file, "r", encoding="utf-8") as f:
return f.read(100).startswith("5.7.0")
except OSError:
return False
def compute_hg_build_descriptor(
src_dir,
build_dir,
......@@ -279,7 +307,7 @@ def compute_hg_build_descriptor(
BUILD_FA_FWD = BUILD_FA_BWD = BUILD_FA_KVCACHE = False
BUILD_FA_PERMUTE = BUILD_FLASHMLA = False
BUILD_C_INTERFACE = False
BUILD_ASM = False
BUILD_ASM = True
FA_DEBUG = True
FA_DEBUG_SUM_MAX = False
HEADDIM_128_ONLY = False
......@@ -358,11 +386,8 @@ def compute_hg_build_descriptor(
GFX_VERSION = "938"
ROCM_PATH = os.environ.get("ROCM_PATH", os.environ.get("ROCM_HOME", "/opt/rocm"))
HG_COMPILER = _resolve_hg_compiler()
if not shutil.which("hipcc"):
raise RuntimeError(
"error: hipcc not found in PATH. Please activate the DTK environment first."
)
if not os.path.isdir(os.path.join(ROCM_PATH, "include")):
raise RuntimeError(
f"error: {ROCM_PATH}/include not found. "
......@@ -444,6 +469,8 @@ def compute_hg_build_descriptor(
DEFINES.append("-DPA_PAGE_BLOCK_SIZE")
if MLA_PAGE_BLOCK_SIZE:
DEFINES.append("-DMLA_PAGE_BLOCK_SIZE")
if _is_rocm_5_7():
DEFINES.append("-DROCM_5_7")
OFFLOAD_FLAGS = [f"--offload-arch=gfx{_g}" for _g in GFX_VERSION.split(";") if _g]
......@@ -457,22 +484,28 @@ def compute_hg_build_descriptor(
INCLUDE_FLAGS += TORCH_INCLUDE_FLAGS
COMMON_FLAGS = [
"-fPIC",
"-O3",
"-std=c++17",
"-fPIC",
"-ffast-math",
"-fno-finite-math-only",
"-fno-gpu-rdc",
"-mno-fma",
]
DTK_DEVICE_FLAGS = [
"-DHIP_ENABLE_WARP_SYNC_BUILTINS",
"-mllvm",
"-slp-phi-tree-bb-max-size=10000",
"-support-768-vgprs=true",
"-mllvm",
"-enable-num-vgprs-512=true",
"-Rpass-analysis=kernel-resource-usage",
"-ftemplate-backtrace-limit=0",
"-disable-machine-sink",
"-mcode-object-version=5",
]
if not _is_rocm_5_7():
DTK_DEVICE_FLAGS += [
"-mllvm",
"-amdgpu-enable-rewrite-partial-reg-uses=false",
"-mllvm",
"-allow-gvn-convergent-call=true",
"-mllvm",
"-disallow-uniform-vmed3-combine=true",
]
if os.environ.get("FLASH_ATTN_HG_SAVE_TEMPS", "") == "1":
DTK_DEVICE_FLAGS.append("--save-temps")
......@@ -555,6 +588,7 @@ def compute_hg_build_descriptor(
"obj_dir": obj_dir,
"sources": _all_sources,
"objects": objects,
"compiler": HG_COMPILER,
"compile_flags": compile_flags,
"link_flags": link_flags,
"out_so": out_so,
......@@ -566,16 +600,17 @@ def run_hg_ninja_build(descriptor: dict) -> None:
"""Write build_hg.ninja and run ninja (parallel via MAX_JOBS)."""
build_dir = descriptor["build_dir"]
ninja_file = descriptor["ninja_path"]
compiler = _ninja_shell_join([descriptor["compiler"]])
out_so_ninja = _ninja_escape_path(descriptor["out_so"])
lines = [
"ninja_required_version = 1.3",
"",
"rule hipcc_compile",
" command = hipcc -c $in -o $out $FLAGS",
"rule hg_compile",
f" command = {compiler} -c $in -o $out $FLAGS",
" description = HG compile $in",
"",
"rule hipcc_link",
" command = hipcc -shared -o $out @$out.rsp $LINK_FLAGS",
"rule hg_link",
f" command = {compiler} -shared -o $out @$out.rsp $LINK_FLAGS",
" rspfile = $out.rsp",
" rspfile_content = $in",
" description = HG link $out",
......@@ -585,9 +620,9 @@ def run_hg_ninja_build(descriptor: dict) -> None:
"",
]
for src, obj in zip(descriptor["sources"], descriptor["objects"]):
lines.append(f"build {_ninja_escape_path(obj)}: hipcc_compile {_ninja_escape_path(src)}")
lines.append(f"build {_ninja_escape_path(obj)}: hg_compile {_ninja_escape_path(src)}")
obj_list = " ".join(_ninja_escape_path(obj) for obj in descriptor["objects"])
lines.append(f"build {out_so_ninja}: hipcc_link {obj_list}")
lines.append(f"build {out_so_ninja}: hg_link {obj_list}")
lines.append("")
os.makedirs(build_dir, exist_ok=True)
......@@ -878,7 +913,6 @@ if not SKIP_CUDA_BUILD:
"csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_sparse_util.cu"
],
extra_compile_args={
"cxx": ["-O3", "-w", "-std=c++17",
......@@ -944,6 +978,7 @@ if not SKIP_CUDA_BUILD:
Path(this_dir) / "csrc" / "flash_attn",
Path(this_dir) / "csrc" / "flash_attn" / "src",
Path(this_dir) / "csrc" / "cutlass" / "include",
"/public/home/huangly/数据采集/cutlass_3.2.1/include" ],
)
)
......@@ -1051,13 +1086,16 @@ class NinjaBuildExtension(BuildExtension):
if os.path.isdir(HG_SRC_DIR):
os.makedirs(HG_BUILD_DIR, exist_ok=True)
_maybe_clean_hg_build_dir(HG_BUILD_DIR)
print("=== Building HG libflash_attention.so (mode=all, gfx938, ninja) ===")
try:
desc = compute_hg_build_descriptor(
HG_SRC_DIR,
HG_BUILD_DIR,
mode="all",
extra_options_raw="-DGFX_VERSION=938 -Wl,-Bsymbolic",
extra_options_raw="-DGFX_VERSION=938;936 -Wl,-Bsymbolic",
)
print(
"=== Building HG libflash_attention.so "
f"(mode=all, gfx938/936, ninja, compiler={desc['compiler']}) ==="
)
run_hg_ninja_build(desc)
if os.path.isfile(HG_SO_BUILD):
......
import os
import sys
import math
import torch
import pickle
import time
import numpy
import argparse
import random
from datetime import datetime
use_cuda_toolkits = os.path.exists("/usr/local/cuda/bin/nvcc")
use_rocm_toolkits = os.path.exists("/opt/rocm/llvm/bin/clang")
use_dtk_toolkits = os.path.exists("/opt/dtk/bin/aicc")
if (use_cuda_toolkits):
from vllm.vllm_flash_attn import flash_attn_varlen_func
elif (use_rocm_toolkits or use_dtk_toolkits):
try:
from flash_attention_interface import flash_attn_varlen_func, flash_attn_2_cuda, flash_attn_with_kvcache
except ModuleNotFoundError:
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
import flash_attn_2_cuda as flash_attn_cuda
def _require_hg_varlen_symbol(name: str):
    """Fetch attribute ``name`` from the ``flash_attn_cuda`` module.

    Raises:
        RuntimeError: if the attribute is absent or is ``None``.
    """
    resolved = getattr(flash_attn_cuda, name, None)
    if resolved is not None:
        return resolved
    raise RuntimeError(
        f"{name} is unavailable in this build. Rebuild flash_attn with HAS_HG_DISPATCH enabled."
    )
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, do_assert=True, cos_threshold=1e-5) -> None:
    """Print error metrics between ``x`` and ``y`` and optionally assert on them.

    Both tensors are converted to float64 first. The printed metrics are:
      * ``cos_diff``  = 1 - 2*sum(x*y) / max(sum(x*x + y*y), 1e-12)
      * ``RMSE``      = sqrt(mean((x - y)^2))
      * ``amax_diff`` = max(|x - y|)
      * ``REL`` / ``rel_max`` = mean / max of ``|x / y|`` (an element-wise
        ratio, so zeros in ``y`` give inf/nan here)

    Raises:
        AssertionError: if the shapes differ, or if ``do_assert`` is truthy
            and ``cos_diff`` is not below ``cos_threshold``.
    """
    assert x.shape == y.shape, "for {}, x and y must have the same shape".format(name)
    x, y = x.double(), y.double()

    delta = x - y
    RMSE = (delta * delta).mean().sqrt().item()

    dot_term = 2 * (x * y).sum().item()
    norm_term = max((x * x + y * y).sum().item(), 1e-12)
    cos_diff = 1 - dot_term / norm_term

    amax_diff = delta.abs().max().item()

    ratio = (x / y).abs()
    rel_diff_mean = ratio.mean().item()
    rel_diff_max = ratio.max().item()

    report = "name:{} cos_diff={:.12f}, RMSE=\x1b[35m{:.12f}\x1b[0m, amax_diff=\x1b[35m{:.12f}\x1b[0m, REL=\x1b[35m{:.12f}\x1b[0m, rel_max=\x1b[35m{:.12f}\x1b[0m"
    print(report.format(name, cos_diff, RMSE, amax_diff, rel_diff_mean, rel_diff_max))

    if do_assert:
        assert cos_diff < cos_threshold
# Reference attention used as the comparison baseline: computes
# softmax(Q @ K^T / sqrt(head_dim)) @ V in float32 with an optional sliding-window or causal mask.
#
# NOTE(review): leading indentation is missing in this view, so the extent of each `if` body below
# cannot be read off the text — confirm against the original file before restructuring.
#
# Inputs are transposed on dims 0/1 before use and the output is transposed back; presumably they
# arrive as (seq, heads, head_dim) — TODO confirm against callers.
# h_q // h_kv is the repeat factor applied to key/value along dim 0 (after the transpose).
# window_size other than (-1, -1) takes precedence over is_causal (if / elif below).
# original_seqlen_kv, split_slice and is_bshd are accepted but not referenced in this body.
#
# Returns a 4-tuple (output, lse, scores_max, scores_sum): output is moved back to the query's
# original device and cast to its original dtype; the other three are moved to that device only.
def scaled_dot_product_attention(__query, __key, __value, h_q, h_kv, is_causal=False, USE_CPU=False, return_max_sum=False, original_seqlen_kv=0, split_slice=0, is_bshd=False, window_size=(-1, -1)):
__query = __query.transpose(0, 1).contiguous()
__key = __key.transpose(0, 1).contiguous()
__value = __value.transpose(0, 1).contiguous()
# Optionally compute the golden reference on the CPU, to avoid the influence of BLAS
original_device = __query.device
original_dtype = __query.dtype
if (USE_CPU):
__query = __query.cpu()
__key = __key.cpu()
__value = __value.cpu()
# print("scaled_dot_product_attention: ", query.shape, key.shape, value.shape)
# All math below is done in float32 regardless of the input dtype.
__query = __query.float()
__key = __key.float()
__value = __value.float()
# If returning in the "official" way (i.e. return_max_sum is False)
if (not return_max_sum):
__key = __key.repeat_interleave(h_q // h_kv, dim=0)
__value = __value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = __query @ __key.transpose(-2, -1) / math.sqrt(__query.size(-1))
# MTP > 1, causal/local mask applied
if (window_size != (-1, -1)):
s_q = __query.shape[-2]
s_k = __key.shape[-2]
left, right = window_size
# A negative bound is replaced by s_k.
if left < 0:
left = s_k
if right < 0:
right = s_k
row_idx = torch.arange(s_q, dtype=torch.int32, device=attn_weight.device)[:, None]
col_idx = torch.arange(s_k, dtype=torch.int32, device=attn_weight.device)[None, :]
# Row r may attend to columns in [r + s_k - s_q - left, r + s_k - s_q + right].
col_idx_limit_left = row_idx + s_k - s_q - left
col_idx_limit_right = row_idx + s_k - s_q + right
temp_mask = (col_idx >= col_idx_limit_left) & (col_idx <= col_idx_limit_right)
attn_weight = attn_weight.masked_fill(temp_mask.logical_not(), float("-inf"))
elif (is_causal):
s_q = __query.shape[-2]
s_k = __key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=__query.dtype, device=attn_weight.device)
# Lower-triangular mask with the diagonal shifted by s_k - s_q.
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=attn_weight.device).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
# NOTE(review): the result of .to() is discarded; attn_bias was already created with this dtype above.
attn_bias.to(__query.dtype)
attn_weight += attn_bias
# some codes for debug
# Row-wise max of the scores and the sum of exp(score - max).
scores_max = attn_weight.to(torch.float32).max(-1)[0]
scores_sum = torch.exp(attn_weight.to(torch.float32) - scores_max.unsqueeze(-1)).sum(dim=-1)
# original codes
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
output = attn_weight @ __value
# Undo the dim 0/1 transpose applied to the inputs at the top.
output = output.transpose(0, 1).contiguous()
return output.to(original_device).to(original_dtype), lse.to(original_device), scores_max.to(original_device), scores_sum.to(original_device)
def set_random_seed(seed=0):
    """Seed every random source this script uses and limit CPU threading.

    Args:
        seed: Integer seed handed to Python's ``random``, NumPy and PyTorch.
    """
    # Python, NumPy and PyTorch (CPU) generators all take the same seed.
    for seed_fn in (random.seed, numpy.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Seed every visible GPU generator when a GPU backend is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Ask cuDNN for deterministic algorithms and disable its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Restrict OpenMP and PyTorch intra-op parallelism to a single thread.
    os.environ.update({'OMP_NUM_THREADS': '1'})
    torch.set_num_threads(1)
if __name__ == '__main__':
    # Stand-alone driver for the HG paged-attention decode kernel:
    #   1. build a random workload (or load a saved one with --load),
    #   2. compute a PyTorch reference with scaled_dot_product_attention,
    #   3. run the kernel once and compare output / LSE against the reference,
    #   4. optionally benchmark it and run a bit-exactness stress test.
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('--load', default=False, action='store_true', help='load path')
    parser.add_argument('--trace', default=False, action='store_true', help='whether dump perf traces')
    parser.add_argument('--bf16', default=False, action='store_true', help='whether use bfloat16 as main dtype')
    parser.add_argument('--fp8', default=False, action='store_true', help='whether use fp8_e4m3 inputs for HG decode')
    parser.add_argument('--pressure', default=False, action='store_true', help='whether do pressure test')
    parser.add_argument('--cpu', default=False, action='store_true', help='whether compute golden via cpu')
    parser.add_argument('--pad', default=False, action='store_true', help='whether make query uncontiguous to simulate vllm behaviors')
    parser.add_argument('--iterations', type=int, default=100, help='pressure test times')
    parser.add_argument('--block_size', type=int, default=128, help='page block_size')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size for generated inputs')
    parser.add_argument('--seq-q', type=int, default=4, help='query length per batch for generated inputs')
    parser.add_argument('--seq-k', type=int, default=2048, help='kv length per batch for generated inputs')
    parser.add_argument('--num-heads', type=int, default=24, help='number of query heads for generated inputs')
    parser.add_argument('--num-heads-kv', type=int, default=2, help='number of kv heads for generated inputs')
    parser.add_argument('--no-causal', dest='causal', default=True, action='store_false', help='disable causal mask for generated inputs')
    parser.add_argument('--window-left', type=int, default=-1, help='left sliding window size')
    parser.add_argument('--window-right', type=int, default=-1, help='right sliding window size')
    # Fixed: the help text used to be a copy of the --pressure one.
    parser.add_argument('--seed', default=False, action='store_true', help='whether fix the random seed for reproducible inputs')
    args = parser.parse_args()
    if (args.seed):
        set_random_seed(212)

    if (args.load):
        # Load the inputs from a file.
        # NOTE(security): torch.load unpickles the file, so only use a trusted dump.
        nvidia_packet = torch.load("./demo.pt")
        query, key, value, cu_seqlens_q, max_seqlen_q, cache_seqlens, max_seqlen_k, softmax_scale, causal, window_size, alibi_slopes, page_table, softcap, fa_version, q_descale, k_descale, v_descale = nvidia_packet["inputs"]
        vllm_golden = nvidia_packet["outputs"]
        # Fixed: everything below reads q / k_cache / v_cache, which this branch
        # never bound (NameError at `q_ref = q`). The packet's key/value are
        # already indexed as (page, block, head, dim) below, i.e. the cache layout.
        q, k_cache, v_cache = query, key, value
        # Derive the remaining parameters from the tensor shapes.
        batch_size = page_table.shape[0]
        assert batch_size == cu_seqlens_q.shape[0] - 1, "check batch size"
        page_block_size = key.shape[1]
        num_heads_kv = key.shape[2]
        num_heads = query.shape[1]
        head_dim_qk = query.shape[2]
        head_dim_v = key.shape[3]
        infer_dtype = query.dtype
    else:
        # Generate a random workload.
        batch_size = args.batch_size
        # Query length of every batch entry and its prefix sum.
        seqlen_q = [args.seq_q for i in range(batch_size)]
        seqlen_q_sum = sum(seqlen_q)
        max_seqlen_q = max(seqlen_q)
        cu_seqlens_q = numpy.array([0] + numpy.cumsum(seqlen_q).tolist()).astype("int32")
        cu_seqlens_q = torch.from_numpy(cu_seqlens_q)
        # KV length of every batch entry.
        cache_seqlens = [args.seq_k for i in range(batch_size)]
        # Page block size: 16 when use_cuda_toolkits is set, otherwise --block_size.
        page_block_size = 16 if (use_cuda_toolkits) else args.block_size
        # Number of pages each sequence actually needs for this block size.
        max_seqlen_k = max(cache_seqlens)
        seqlen_kv_real_required_page = [math.ceil(it / page_block_size) for it in cache_seqlens]
        seqlen_kv_real_required_page_sum = sum(seqlen_kv_real_required_page)
        # The cache itself is sized as if every entry had the longest KV length.
        seqlen_kv_max_required_page = math.ceil(max_seqlen_k / page_block_size)
        seqlen_kv_max_required_page_total = batch_size * seqlen_kv_max_required_page
        # Shuffle the page table so logical pages map to scattered physical pages.
        shuffle = True
        if (shuffle):
            block_random = torch.randperm(seqlen_kv_max_required_page_total, dtype=torch.int32, device="cuda")
        else:
            block_random = torch.arange(seqlen_kv_max_required_page_total , dtype=torch.int32)
        page_table = []
        seq_block_incre = 0
        for i in range(batch_size):
            # Unused trailing slots of a row stay 0.
            blocks_pad = [0] * seqlen_kv_max_required_page
            if (shuffle):
                blocks_pad[:seqlen_kv_real_required_page[i]] = block_random[seq_block_incre: seq_block_incre + seqlen_kv_real_required_page[i]].cpu().tolist()
                seq_block_incre += seqlen_kv_real_required_page[i]
            else:
                blocks_pad = block_random[seq_block_incre: seq_block_incre + seqlen_kv_max_required_page].cpu().tolist()
                seq_block_incre += seqlen_kv_max_required_page
            page_table.append(torch.IntTensor(blocks_pad))
        page_table = torch.stack(page_table).contiguous().to("cuda")
        # Basic parameters.
        head_dim_qk = 128
        head_dim_v = 128
        num_heads = args.num_heads
        num_heads_kv = args.num_heads_kv
        # float16 by default; --bf16 switches the main dtype to bfloat16.
        infer_dtype = torch.float16
        if (args.bf16): infer_dtype = torch.bfloat16
        softmax_scale = 1.0 / math.sqrt(head_dim_qk)
        causal = args.causal
        window_size = (args.window_left, args.window_right)
        alibi_slopes = None
        softcap = 0.0
        fa_version = 2
        q_descale = torch.ones((batch_size, num_heads), dtype=torch.float32, device="cuda")
        k_descale = torch.ones((batch_size, num_heads_kv), dtype=torch.float32, device="cuda")
        v_descale = torch.ones((batch_size, num_heads_kv), dtype=torch.float32, device="cuda")
        # Create the input tensors.
        if (args.pad):
            # Slice a wider tensor so q is non-contiguous along the head axis.
            query_origin_tensor = torch.randn((seqlen_q_sum, num_heads + 16, head_dim_qk), dtype=infer_dtype, device="cuda")
            q = query_origin_tensor[:, :num_heads]
        else:
            q = torch.randn((seqlen_q_sum, num_heads, head_dim_qk), dtype=infer_dtype, device="cuda")
        k_cache = torch.randn((seqlen_kv_max_required_page_total, page_block_size, num_heads_kv, head_dim_qk), device="cuda", dtype=infer_dtype)
        v_cache = torch.randn((seqlen_kv_max_required_page_total, page_block_size, num_heads_kv, head_dim_v), device="cuda", dtype=infer_dtype)
        vllm_golden = None
        cu_seqlens_q = cu_seqlens_q.to(q.device)
        cache_seqlens = torch.from_numpy(numpy.array(cache_seqlens).astype("int32")).to(q.device)

    # The reference consumes *_ref; with --fp8 these are the quantised values
    # cast back to infer_dtype, so kernel and reference see the same numbers.
    q_ref = q
    k_cache_ref = k_cache
    v_cache_ref = v_cache
    if args.fp8:
        if not hasattr(torch, "float8_e4m3fn"):
            raise RuntimeError("This PyTorch build does not support torch.float8_e4m3fn")
        q = q.to(torch.float8_e4m3fn)
        k_cache = k_cache.to(torch.float8_e4m3fn)
        v_cache = v_cache.to(torch.float8_e4m3fn)
        q_ref = q.to(infer_dtype)
        k_cache_ref = k_cache.to(infer_dtype)
        v_cache_ref = v_cache.to(infer_dtype)

    # Show the inputs.
    print("--------------------------------------------------------------------------------------------")
    print("q: ", q.shape, q.dtype, q.is_contiguous(), q.stride())
    print("k_cache: ", k_cache.shape, k_cache.dtype, k_cache.is_contiguous(), k_cache.stride())
    print("v_cache: ", v_cache.shape, v_cache.dtype, v_cache.is_contiguous(), v_cache.stride())
    print("cu_seqlens_q: ", cu_seqlens_q.shape, cu_seqlens_q.dtype, cu_seqlens_q.is_contiguous())
    print("cu_seqlens_q: ", cu_seqlens_q)
    print("max_seqlen_q: ", max_seqlen_q)
    print("cache_seqlens: ", cache_seqlens)
    print("max_seqlen_k: ", max_seqlen_k)
    print("softmax_scale: ", softmax_scale)
    print("causal: ", causal)
    print("window_size: ", window_size)
    print("alibi_slopes: ", alibi_slopes)
    print("page_table: ", page_table.shape, page_table.dtype, page_table.is_contiguous(), page_table.stride())
    print("page_table: ", page_table)
    print("softcap: ", softcap)
    print("fa_version: ", fa_version)
    print("q_descale: ", q_descale.shape, q_descale.dtype, q_descale.tolist())
    print("k_descale: ", k_descale.shape, k_descale.dtype, k_descale.tolist())
    print("v_descale: ", v_descale.shape, v_descale.dtype, v_descale.tolist())
    print("--------------------------------------------------------------------------------------------")

    # Rebuild dense per-batch key / value tensors from the paged KV cache.
    key_original = []
    value_original = []
    for b in range(batch_size):
        # Page-table row of this batch entry.
        index = page_table[b]
        # Keep only the pages that are actually in use.
        max_page_blocks = math.ceil(cache_seqlens[b] / page_block_size)
        actual_index = index[:max_page_blocks]
        # Gather those pages, flatten (page, block) into one token axis and
        # drop the padding at the end of the last page.
        key_content = k_cache_ref[actual_index]
        key_content = key_content.view(-1, num_heads_kv, head_dim_qk)[:cache_seqlens[b]].contiguous()
        # Same for value.
        value_content = v_cache_ref[actual_index].view(-1, num_heads_kv, head_dim_v)[:cache_seqlens[b]].contiguous()
        key_original.append(key_content)
        value_original.append(value_content)

    # Likewise split the packed query into per-batch pieces.
    query_original = []
    cum_q = 0
    for b in range(batch_size):
        query_len = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
        query_content = q_ref[cum_q: cum_q + query_len]
        query_original.append(query_content.contiguous())
        cum_q += query_len

    # Reference attention, one batch entry at a time.
    golden = []
    golden_lse = []
    golden_max = []
    for b in range(batch_size):
        tmp_output, lse, scores_max, scores_sum = scaled_dot_product_attention(query_original[b], key_original[b], value_original[b], num_heads, num_heads_kv, is_causal=causal, USE_CPU=args.cpu, window_size=window_size)
        golden.append(tmp_output)
        golden_lse.append(lse)
        golden_max.append(scores_max)
    golden = torch.cat(golden, dim=0)
    golden_lse = torch.cat(golden_lse, dim=-1)
    golden_max = torch.cat(golden_max, dim=-1)
    print("golden: ", golden.shape)
    print("golden_lse: ", golden_lse.shape)
    print("--------------------------------------------------------------------------------------------")

    def run_hg_decode(q_in, k_in, v_in):
        # Single call site for the HG kernel; the accuracy run, the benchmark
        # and the stress test previously each repeated this argument list.
        return bshd_pa_decode(
            q_in,
            k_in,
            v_in,
            None,  # out_
            cu_seqlens_q,
            None,  # cu_seqlens_k
            cache_seqlens,
            alibi_slopes,
            page_table,
            max_seqlen_q,
            max_seqlen_k,
            0.0,  # dropout
            softmax_scale,
            False,  # zero_tensors
            causal,
            window_size[0],
            window_size[1],
            softcap,
            True,  # return_softmax_lse
            1,  # meaning of this positional argument is not visible in this file
            q_descale if args.fp8 else None,
            k_descale if args.fp8 else None,
            v_descale if args.fp8 else None,
            infer_dtype == torch.bfloat16,
        )

    # Presumably a manual toggle: the else branch keeps the upstream
    # flash_attn_with_kvcache call for comparison and is never taken as written.
    if (True):
        # fa_output, fa_lse = flash_attn_2_cuda.prefix_decode_varlen_fwd(
        bshd_pa_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")
        fa_output, fa_lse = run_hg_decode(q, k_cache, v_cache)
    else:
        fa_output, fa_lse, *rest = flash_attn_with_kvcache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            page_table=page_table,
            cache_seqlens=cache_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k_new=cache_seqlens,
            max_seqlen_q=max_seqlen_q,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            k_descale=k_descale,
            v_descale=v_descale,
            return_softmax_lse=True,
        )
    torch.cuda.synchronize()
    if (vllm_golden is not None):
        # Compare against the saved outputs to validate the dump/load flow.
        cal_diff(fa_output, vllm_golden, "check")
    print("fa_output: ", fa_output.shape)
    if (fa_lse is not None): print("fa_lse: ", fa_lse.shape)
    # Accuracy check against the reference (looser threshold for fp8).
    fp8_threshold = 5e-3
    cal_diff(golden, fa_output, "accuracy", True, fp8_threshold if args.fp8 else 1e-5)
    if (fa_lse is not None): cal_diff(golden_lse, fa_lse, "softmax_lse", True, fp8_threshold if args.fp8 else 1e-5)
    print("--------------------------------------------------------------------------------------------")

    # Benchmark.
    def benchmark_prefix_prefill():
        _ = run_hg_decode(q, k_cache, v_cache)

    # Skipped while debugging (FA_DEBUG / HIP_LOG_LEVEL set) or tracing.
    if ((os.getenv("FA_DEBUG") is None) and (os.getenv("HIP_LOG_LEVEL") is None) and not args.trace):
        # Imported only here, so debug runs do not need triton installed.
        import triton
        t = triton.testing.do_bench_cudagraph(benchmark_prefix_prefill)
        FLOPS = float(0)
        BYTES = float(0)
        for b in range(batch_size):
            batch_seqlen_q = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
            batch_seqlen_k = cache_seqlens[b]
            # With a causal mask roughly half of the q x q square is skipped.
            undo_flops = batch_seqlen_q * batch_seqlen_q / 2 if (causal) else 0
            qk_flops = num_heads * (batch_seqlen_q * batch_seqlen_k - undo_flops) * head_dim_qk * 2
            pv_flops = num_heads * (batch_seqlen_q * batch_seqlen_k - undo_flops) * head_dim_v * 2
            FLOPS += qk_flops + pv_flops
            q_load = batch_seqlen_q * num_heads * head_dim_qk
            k_load = batch_seqlen_k * num_heads_kv * head_dim_qk # k load not only once
            v_load = batch_seqlen_k * num_heads_kv * head_dim_v
            BYTES += q_load * q.element_size() + k_load * k_cache.element_size() + v_load * v_cache.element_size() # output stores are not counted
        # t is in milliseconds: X / 1e9 / ms gives tera-units/s, X / 1e6 / ms gives giga-units/s.
        print(f"Performance: {t:.3f} ms, \x1b[35m{FLOPS / 10 ** 9 / t:.2f}\x1b[0m TFLOPS, \x1b[35m{BYTES / 10 ** 6 / t:.0f}\x1b[0m GB/s")

    # Stress test: repeated runs on fresh copies must reproduce fa_output bit-for-bit.
    if (args.pressure):
        # At least 100 iterations; --iterations can only raise the count.
        pressure_count = max(100, args.iterations)
        for p in range(pressure_count):
            pressure_fa_output = torch.zeros_like(fa_output)
            pressure_fa_output, _ = run_hg_decode(q.clone(), k_cache.clone(), v_cache.clone())
            torch.cuda.synchronize()
            is_equal = torch.equal(pressure_fa_output, fa_output)
            if (not is_equal): cal_diff(pressure_fa_output, fa_output, "pressure")
            assert is_equal, "\x1b[31mUnstable\x1b[0m!"
            del pressure_fa_output
            sys.stdout.write("\rPressure Test: {}/{}".format(p + 1, pressure_count))
        print(" \x1b[32mPASS\x1b[0m")
    print("-----------------------------------------------------------------------------------")
import torch
import os
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -6,6 +7,27 @@ from vllm.triton_utils import tl, triton
import math
import time
def estimate_unified_attention_bytes(
    batch_size,
    seqlen_q,
    seqlen_k,
    nheads,
    nheads_k,
    d,
    block_size,
    q_bytes,
    k_bytes,
    v_bytes,
    out_bytes=0,
):
    """Estimate the bytes moved by one unified paged-attention call.

    Args:
        batch_size: Number of sequences.
        seqlen_q: Query tokens per sequence.
        seqlen_k: KV tokens per sequence.
        nheads: Number of query heads.
        nheads_k: Number of KV heads.
        d: Head dimension.
        block_size: Tokens per KV page.
        q_bytes, k_bytes, v_bytes: Bytes per element of Q, K and V.
        out_bytes: Bytes per output element; the default 0 ignores the output.

    Returns:
        Sum of the Q, K+V and output payloads plus the metadata term.
    """
    index_bytes = 4
    pages_per_sequence = math.ceil(seqlen_k / block_size)

    q_elements = batch_size * seqlen_q * nheads * d
    kv_elements = seqlen_k * batch_size * nheads_k * d

    # Three 4-byte index terms: (batch_size + 1) entries, batch_size entries and
    # one entry per KV page per sequence -- presumably cu_seqlens_q, seqused_k
    # and the block table; confirm against the kernel interface.
    metadata = (
        (batch_size + 1) * index_bytes
        + batch_size * index_bytes
        + batch_size * pages_per_sequence * index_bytes
    )

    return (
        q_elements * q_bytes
        + kv_elements * (k_bytes + v_bytes)
        + q_elements * out_bytes
        + metadata
    )
import pytest
import torch
import torch.nn.functional as F
......@@ -18,19 +40,19 @@ import pdb
from einops import rearrange, repeat
from flash_attn import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
# flash_attn_func,
# flash_attn_kvpacked_func,
# flash_attn_qkvpacked_func,
# flash_attn_varlen_func,
# flash_attn_varlen_kvpacked_func,
# flash_attn_varlen_qkvpacked_func,
# flash_attn_with_kvcache,
varlen_fwd_unified,
)
from flash_attn import flash_attn_func
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb
# from flash_attn import flash_attn_func
# from flash_attn.bert_padding import pad_input, unpad_input
# from flash_attn.flash_attn_interface import _get_block_size_n
# from flash_attn.layers.rotary import apply_rotary_emb
MAX_HEADDIM_SM8x = 192
......@@ -1245,60 +1267,102 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False,
if is_e5m2:
assert cos_diff < 1e-2
elif use_fp8:
assert cos_diff < 1e-3
assert cos_diff < 5e-3
else:
assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5)
class _NoUnifiedFallback:
def __init__(self, module):
self._module = module
def __getattr__(self, name):
if name == "varlen_fwd_unified":
raise AssertionError("unexpected fallback to flash_attn_cuda.varlen_fwd_unified")
return getattr(self._module, name)
def varlen_fwd_unified_expect_hg(expected_symbol, *args, **kwargs):
    """Call ``varlen_fwd_unified`` and assert it requested a given HG symbol.

    While the call runs, two globals of ``varlen_fwd_unified``'s module are
    replaced: ``flash_attn_cuda`` by a ``_NoUnifiedFallback`` proxy, and
    ``_require_hg_varlen_symbol`` by a wrapper that records each requested name.
    Both are restored afterwards, even if the call raises.

    Args:
        expected_symbol: Name that must be among the recorded symbol requests.
        *args, **kwargs: Forwarded unchanged to ``varlen_fwd_unified``.

    Returns:
        Whatever ``varlen_fwd_unified`` returns.

    Raises:
        AssertionError: If ``expected_symbol`` was never requested, or if the
            proxy sees an access to ``varlen_fwd_unified``.
    """
    target_globals = varlen_fwd_unified.__globals__
    saved = {
        key: target_globals[key]
        for key in ("flash_attn_cuda", "_require_hg_varlen_symbol")
    }
    requested = []

    def recording_require(name):
        # Record the lookup, then defer to the real resolver.
        requested.append(name)
        return saved["_require_hg_varlen_symbol"](name)

    target_globals["flash_attn_cuda"] = _NoUnifiedFallback(saved["flash_attn_cuda"])
    target_globals["_require_hg_varlen_symbol"] = recording_require
    try:
        result = varlen_fwd_unified(*args, **kwargs)
    finally:
        # Put the original globals back regardless of the outcome.
        target_globals.update(saved)

    assert expected_symbol in requested, (
        f"expected {expected_symbol}, got HG calls {requested}"
    )
    return result
# ---------------------------------------------------------------------------
# Accuracy tests
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("mha_type", ["gqa"])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("window_size", [(-1, -1)])
@pytest.mark.parametrize("use_alibi_sqrt", [True, False])
@pytest.mark.parametrize("use_qq_bias", [True, False]) # seqlen_q > seqlen_k 时 skip
@pytest.mark.parametrize("use_sinks", [True, False])
@pytest.mark.parametrize("use_mm_prefix", [True, False])
@pytest.mark.parametrize("d", [128, 256])
@pytest.mark.parametrize("window_size", [(-1, -1), (511, 0)])
@pytest.mark.parametrize("use_alibi_sqrt", [False])
@pytest.mark.parametrize("use_qq_bias", [False]) # seqlen_q > seqlen_k 时 skip
@pytest.mark.parametrize("use_sinks", [False])
@pytest.mark.parametrize("use_mm_prefix", [False])
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
"batch_size,seqlen_q,seqlen_k,block_size",
[
# --- 场景 1: 标准 Prefill (全量预填充) ---
# 验证对角线处理、Is_causal 逻辑以及全量 Bias 覆盖
(1, 512, 512, 128), # 单 Batch 小尺寸,快速验证
(4, 2048, 2048, 128), # 匹配你日志的大尺寸,验证多 Batch 偏移
(2, 1024, 1024, 128), # 较小的 block_size,增加循环迭代次数
# (1, 512, 512, 128), # 单 Batch 小尺寸,快速验证
# (4, 2048, 2048, 128), # 匹配你日志的大尺寸,验证多 Batch 偏移
# (2, 1024, 1024, 128), # 较小的 block_size,增加循环迭代次数
# --- 场景 2: Decode 场景 (增量推理) ---
# 验证 seqlen_q=1 时,如何正确从 KV Cache 的最后位置读取信息
# 此时 qq_bias 实际上只退化为向量加法,最容易测出指针偏移错误
(8, 1, 2048, 128), # 高 Batch 的标准 Decode
(1, 1, 4096, 128), # 超长上下文 Decode,验证大索引寻址
# --- 场景 3: Chunked Prefill / Speculative Decoding (分段/投机采样) ---
# Q 小于 K,但大于 1。这是最难写的逻辑,验证 Is_causal 的动态截断
(2, 128, 1024, 128), # Q 是一小段,K 是长历史
(4, 256, 512, 128), # 验证 Q 和 K 比例较近时的处理
# --- 场景 4: 边界非对称尺寸 (非 2 的幂次) ---
# 专门用来抓那些“假设数据一定是 BlockSize 整数倍”的 Bug
(1, 127, 127, 128), # 刚好差 1 个填满 Block
(2, 33, 1025, 128), # 非常细碎的 Block 和不规则长度
(32, 4, 2048, 128),
(16, 4, 2048, 128),
(8, 4, 2048, 128), # 高 Batch 的标准 Decode
(4, 4, 2048, 128),
(2, 4, 2048, 128),
(1, 4, 4096, 128), # 超长上下文 Decode,验证大索引寻址
# --- 场景 3: Prefix Prefill ---
(1, 16, 128, 128), # fp8 prefill lower boundary
(1, 32, 512, 128),
(2, 32, 513, 128), # non block-aligned KV length
(2, 64, 2048, 128),
(3, 96, 1537, 128), # non power-of-two batch/length
(2, 128, 1024, 128),
# # --- 场景 4: 边界非对称尺寸 (非 2 的幂次) ---
# # 专门用来抓那些“假设数据一定是 BlockSize 整数倍”的 Bug
# (1, 127, 127, 128), # 刚好差 1 个填满 Block
# (2, 33, 1025, 128), # 非常细碎的 Block 和不规则长度
],
)
def test_unified_attn_2d(
batch_size, seqlen_q, seqlen_k, block_size,
d, causal, window_size, softcap,
mha_type, dtype,
use_alibi_sqrt, use_qq_bias, use_sinks, use_mm_prefix,
use_alibi_sqrt, use_qq_bias, use_sinks, use_mm_prefix, use_fp8,
):
device = torch.device("cuda")
torch.manual_seed(42)
nheads = 8
nheads_k = 1 if mha_type == "gqa" else nheads
nheads = 24
nheads_k = 2 if mha_type == "gqa" else nheads
softmax_scale = d ** (-0.5)
MAX_MM_RANGES = 2
......@@ -1310,6 +1374,8 @@ def test_unified_attn_2d(
if use_mm_prefix and seqlen_q > seqlen_k:
pytest.skip("mm_prefix not supported when seqlen_q > seqlen_k")
if causal and window_size != (-1, -1):
pytest.skip("HG local window path is selected with causal=False")
# if use_mm_prefix and not causal:
# pytest.skip("mm_prefix_range is only meaningful with causal=True")
......@@ -1320,13 +1386,28 @@ def test_unified_attn_2d(
k_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype))
v_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype))
q_varlen = torch.cat(q_list, dim=0)
if use_fp8:
fp8_dtype = current_platform.fp8_dtype()
q_kernel_list = [q.to(fp8_dtype) for q in q_list]
k_kernel_list = [k.to(fp8_dtype) for k in k_list]
v_kernel_list = [v.to(fp8_dtype) for v in v_list]
q_ref_list = [q.to(dtype) for q in q_kernel_list]
k_ref_list = [k.to(dtype) for k in k_kernel_list]
v_ref_list = [v.to(dtype) for v in v_kernel_list]
else:
q_kernel_list, k_kernel_list, v_kernel_list = q_list, k_list, v_list
q_ref_list, k_ref_list, v_ref_list = q_list, k_list, v_list
q_varlen = torch.cat(q_kernel_list, dim=0)
cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(
torch.tensor([seqlen_q] * batch_size, dtype=torch.int32), dim=0
)
seqused_k = torch.tensor([seqlen_k] * batch_size, device=device, dtype=torch.int32)
k_cache, v_cache, block_table = make_paged_kv(k_list, v_list, block_size, device, dtype)
k_cache, v_cache, block_table = make_paged_kv(k_kernel_list, v_kernel_list, block_size, device, q_varlen.dtype)
q_descale = torch.ones((batch_size, nheads), device=device, dtype=torch.float32) if use_fp8 else None
k_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32) if use_fp8 else None
v_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32) if use_fp8 else None
# Build optional tensors
alibi_slopes = None
......@@ -1354,7 +1435,7 @@ def test_unified_attn_2d(
for i in range(batch_size):
ref_outs.append(
ref_attn(
q_list[i], k_list[i], v_list[i],
q_ref_list[i], k_ref_list[i], v_ref_list[i],
causal=causal,
window_size=window_size,
softmax_scale=softmax_scale,
......@@ -1369,7 +1450,20 @@ def test_unified_attn_2d(
ref_out = torch.cat(ref_outs, dim=0)
# ---- CUDA kernel ----
cuda_out, cuda_lse = varlen_fwd_unified(
expected_hg_symbol = None
if use_fp8 and seqlen_q >= 16:
expected_hg_symbol = "hg_prefix_prefill_varlen_fwd"
elif use_fp8 or seqlen_q == 1 or 1 < seqlen_q < 16:
expected_hg_symbol = "hg_prefix_decode_varlen_fwd"
elif seqlen_q > 16:
expected_hg_symbol = "hg_prefix_prefill_varlen_fwd"
varlen_runner = (
varlen_fwd_unified
if expected_hg_symbol is None
else lambda *args, **kwargs: varlen_fwd_unified_expect_hg(expected_hg_symbol, *args, **kwargs)
)
cuda_out, cuda_lse = varlen_runner(
q_varlen, k_cache, v_cache,
cu_seqlens_q, seqused_k, block_table,
max_seqlen_q=seqlen_q,
......@@ -1384,6 +1478,9 @@ def test_unified_attn_2d(
s_aux=sinks,
mm_prefix_range=mm_prefix_range,
return_softmax_lse=True,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
# # ---- Triton kernel ----
......@@ -1416,7 +1513,7 @@ def test_unified_attn_2d(
# triton_max_diff = (triton_out - ref_out).abs().max().item()
print(
f"\n[{dtype} | causal={causal} | {mha_type} | bs={batch_size} "
f"\n[{dtype} | fp8={use_fp8} | causal={causal} | {mha_type} | bs={batch_size} "
f"sq={seqlen_q} sk={seqlen_k} blk={block_size} | "
f"alibi_sqrt={use_alibi_sqrt} qq_bias={use_qq_bias} "
f"sinks={use_sinks} mm_prefix={use_mm_prefix}]"
......@@ -1424,7 +1521,7 @@ def test_unified_attn_2d(
# f"\n Triton max_diff={triton_max_diff:.4e}"
)
cal_diff(cuda_out, ref_out, "out")
cal_diff(cuda_out, ref_out, "out", use_fp8=use_fp8)
# ---------------------------------------------------------------------------
......@@ -1440,8 +1537,8 @@ def benchmark_unified_attention():
torch.manual_seed(42)
dtype = torch.float16
d = 256
block_size = 320
d = 128
block_size = 128
warmup = 10
repeat = 50
......@@ -1452,30 +1549,35 @@ def benchmark_unified_attention():
MAX_MM_RANGES = 2
# GQA
nheads = 8
nheads_k = 1
nheads = 24
nheads_k = 2
# workload shapes
shapes = [
# (8, 2048, 2048),
# (4, 2048, 2048),
(4, 1, 2048),
(1, 4, 51200),
(2, 4, 51200),
(4, 4, 51200),
(8, 4, 51200),
(16, 4, 51200),
(32, 4, 51200),
# (4, 2048, 4096),
]
# feature configs (C A Q S P)
feature_configs = [
(1,0,0,0,0),
(1,1,0,0,0),
(1,0,1,0,0),
(1,0,0,1,0),
(1,0,0,0,1),
(1,1,1,0,0),
(1,1,0,1,0),
(1,1,0,0,1),
(1,1,1,1,0),
(1,1,1,0,1),
(1,1,1,1,1),
# (1,1,0,0,0),
# (1,0,1,0,0),
# (1,0,0,1,0),
# (1,0,0,0,1),
# (1,1,1,0,0),
# (1,1,0,1,0),
# (1,1,0,0,1),
# (1,1,1,1,0),
# (1,1,1,0,1),
# (1,1,1,1,1),
]
print("\nUnified Attention GQA Benchmark")
......@@ -1485,7 +1587,8 @@ def benchmark_unified_attention():
f"{'BS':>3} {'SQ':>6} {'SK':>6} | "
f"{'C':>1} {'A':>1} {'Q':>1} {'S':>1} {'P':>1} | "
f"{'CUDA(ms)':>10} {'Triton(ms)':>11} | "
f"{'CUDA TFLOPS':>11} {'Triton TFLOPS':>13} | {'Speedup':>8}"
f"{'CUDA TFLOPS':>11} {'Triton TFLOPS':>13} | "
f"{'CUDA(GB/s)':>10} {'Triton(GB/s)':>12} | {'Speedup':>8}"
)
print("-" * 120)
......@@ -1548,6 +1651,13 @@ def benchmark_unified_attention():
triton_out = torch.zeros_like(q_varlen)
total_bytes = estimate_unified_attention_bytes(
batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, block_size,
q_bytes=torch.finfo(dtype).bits // 8,
k_bytes=torch.finfo(dtype).bits // 8,
v_bytes=torch.finfo(dtype).bits // 8,
)
for C, A, Q, S, P in feature_configs:
causal = bool(C)
......@@ -1655,6 +1765,9 @@ def benchmark_unified_attention():
# FLOPs
flops = 4.0 * batch_size * nheads * seqlen_q * seqlen_k * d
cuda_bandwidth = total_bytes / 1e9 / cuda_ms * 1000
triton_bandwidth = total_bytes / 1e9 / triton_ms * 1000
cuda_tflops = flops / cuda_ms / 1e9
triton_tflops = flops / triton_ms / 1e9
......@@ -1663,11 +1776,130 @@ def benchmark_unified_attention():
f"{C} {A} {Q} {S} {P} | "
f"{cuda_ms:10.3f} {triton_ms:11.3f} | "
f"{cuda_tflops:11.2f} {triton_tflops:13.2f} | "
f"{cuda_bandwidth:10.2f} {triton_bandwidth:12.2f} | "
f"{triton_ms/cuda_ms:8.2f}x"
)
print("=" * 120)
def benchmark_hg_b16_fp8_pa():
    """Time ``varlen_fwd_unified`` with BF16 inputs versus FP8 inputs.

    For each (window, batch size, query length, KV length) case this builds
    random BF16 Q/K/V, pages K/V with ``make_paged_kv``, casts copies to the
    platform FP8 dtype, times both variants and prints latency, the BF16/FP8
    speedup and a bandwidth estimate from ``estimate_unified_attention_bytes``.
    """
    device = torch.device("cuda")
    torch.manual_seed(42)

    dtype = torch.bfloat16
    fp8_dtype = current_platform.fp8_dtype()
    d = 128
    block_size = 128
    warmup = 10
    repeat = 50
    nheads = 24
    nheads_k = 2
    softcap = 0.0
    softmax_scale = d ** (-0.5)

    # (batch, seqlen_q, seqlen_k) cases: short queries over a long context
    # first, then longer queries over a shorter context.
    shapes = []
    for kv_len, q_lens, batch_sizes in (
        (51200, (1, 4), (1, 2, 4, 8, 16, 32)),
        (4096, (32, 128), (1, 2, 4)),
    ):
        for q_len in q_lens:
            for bs in batch_sizes:
                shapes.append((bs, q_len, kv_len))
    windows = [(-1, -1), (511, 0)]

    def time_fn(fn):
        # Untimed warm-up, then mean wall-clock milliseconds per call.
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(repeat):
            fn()
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
        return elapsed / repeat * 1000

    print("\nHG Unified PA BF16/FP8 Benchmark")
    print("=" * 104)
    print(
        f"{'BS':>3} {'SQ':>3} {'SK':>6} {'WINDOW':>12} | "
        f"{'BF16(ms)':>9} {'FP8(ms)':>9} {'Speedup':>8} | "
        f"{'BF16 GB/s':>10} {'FP8 GB/s':>9}"
    )
    print("-" * 104)

    for window_size in windows:
        uses_window = window_size != (-1, -1)
        for batch_size, seqlen_q, seqlen_k in shapes:
            # Windowed attention is only benchmarked for query lengths below 16.
            if uses_window and seqlen_q >= 16:
                continue

            def random_batch(tokens, heads):
                return [
                    torch.randn(tokens, heads, d, device=device, dtype=dtype)
                    for _ in range(batch_size)
                ]

            # Generation order (Q, then K, then V) fixes the random stream.
            q_list = random_batch(seqlen_q, nheads)
            k_list = random_batch(seqlen_k, nheads_k)
            v_list = random_batch(seqlen_k, nheads_k)

            q_b16 = torch.cat(q_list, dim=0)
            k_b16, v_b16, block_table = make_paged_kv(k_list, v_list, block_size, device, dtype)
            q_fp8 = q_b16.to(fp8_dtype)
            k_fp8 = k_b16.to(fp8_dtype)
            v_fp8 = v_b16.to(fp8_dtype)

            # Prefix sums of the (uniform) query lengths, with a leading zero.
            cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
            q_lens_cpu = torch.tensor([seqlen_q] * batch_size, dtype=torch.int32)
            cu_seqlens_q[1:] = torch.cumsum(q_lens_cpu, dim=0)
            seqused_k = torch.tensor([seqlen_k] * batch_size, device=device, dtype=torch.int32)

            # All-ones descale factors for the FP8 run.
            q_descale = torch.ones((batch_size, nheads), device=device, dtype=torch.float32)
            k_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32)
            v_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32)

            # NOTE(review): causal=True is passed even for the (511, 0) window;
            # confirm this reaches the intended windowed kernel path.
            shared_kwargs = dict(
                max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k,
                softmax_scale=softmax_scale, causal=True,
                window_size=window_size, softcap=softcap,
                return_softmax_lse=False,
            )

            def run_b16():
                varlen_fwd_unified(
                    q_b16, k_b16, v_b16, cu_seqlens_q, seqused_k, block_table,
                    **shared_kwargs,
                )

            def run_fp8():
                varlen_fwd_unified(
                    q_fp8, k_fp8, v_fp8, cu_seqlens_q, seqused_k, block_table,
                    q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
                    **shared_kwargs,
                )

            # One untimed call of each variant before measuring.
            run_b16()
            run_fp8()
            torch.cuda.synchronize()
            b16_ms = time_fn(run_b16)
            fp8_ms = time_fn(run_fp8)

            def traffic_bytes(input_elem_bytes):
                # Inputs use the given element size; the output is counted at 2 bytes.
                return estimate_unified_attention_bytes(
                    batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, block_size,
                    q_bytes=input_elem_bytes, k_bytes=input_elem_bytes,
                    v_bytes=input_elem_bytes, out_bytes=2,
                )

            b16_bytes = traffic_bytes(2)
            fp8_bytes = traffic_bytes(1)
            print(
                f"{batch_size:3d} {seqlen_q:3d} {seqlen_k:6d} {str(window_size):>12} | "
                f"{b16_ms:9.3f} {fp8_ms:9.3f} {b16_ms / fp8_ms:8.2f} | "
                f"{b16_bytes / 1e9 / b16_ms * 1000:10.2f} {fp8_bytes / 1e9 / fp8_ms * 1000:9.2f}"
            )
    print("=" * 104)
if __name__ == "__main__":
benchmark_unified_attention()
\ No newline at end of file
if os.getenv("RUN_UNIFIED_BENCHMARK") == "1":
benchmark_hg_b16_fp8_pa()
else:
test_unified_attn_2d(
batch_size=1, seqlen_q=1, seqlen_k=40960, block_size=128,
d=128, causal=False, window_size=(511, 0), softcap=0.0,
mha_type="gqa", dtype=torch.bfloat16,
use_alibi_sqrt=False, use_qq_bias=False, use_sinks=False, use_mm_prefix=False,
use_fp8=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment