Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
// Copyright (c) 2025, Xin Zhou.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template<>
void run_fp8_mha_fwd_prefix_prefill_<BFloat16, 256, 256>(Flash_fwd_params &params, hipStream_t stream) {
#ifdef BUILD_FA_FWD
run_fp8_flash_fwd_prefix_prefill<BFloat16, 256, 256>(params, stream);
#endif
}
// Copyright (c) 2025, Xin Zhou.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template<>
void run_fp8_mha_fwd_prefix_prefill_<Float16, 256, 256>(Flash_fwd_params &params, hipStream_t stream) {
#ifdef BUILD_FA_FWD
run_fp8_flash_fwd_prefix_prefill<Float16, 256, 256>(params, stream);
#endif
}
// Copyright (c) 2025, Wenjian Zhang.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template<>
void run_fp8_mha_fwd_prefix_prefill_<BFloat16, 192, 128>(Flash_fwd_params &params, hipStream_t stream) {
#ifdef BUILD_FA_FWD
run_fp8_flash_fwd_prefix_prefill<BFloat16, 192, 128>(params, stream);
#endif
}
// Copyright (c) 2025, Wenjian Zhang.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template<>
void run_fp8_mha_fwd_prefix_prefill_<Float16, 192, 128>(Flash_fwd_params &params, hipStream_t stream) {
#ifdef BUILD_FA_FWD
run_fp8_flash_fwd_prefix_prefill<Float16, 192, 128>(params, stream);
#endif
}
#ifdef BUILD_FA_PERMUTE #ifdef BUILD_FA_PERMUTE
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include "../../include/intrinsic.h"
#include "../flash_fwd_permute_hdim128.h" #include "../flash_fwd_permute_hdim128.h"
template<> template<>
...@@ -140,28 +139,22 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<128, 4, 32>( ...@@ -140,28 +139,22 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<128, 4, 32>(
int32_t block_offset = seqlen_limit * kHeadDim; int32_t block_offset = seqlen_limit * kHeadDim;
int32_t thread_offset = lane_id_col * 8; int32_t thread_offset = lane_id_col * 8;
// 一次读取 4x128 的 Half 到 LDS // 一次读取 4x128 的 Half 到 LDS
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
{ __builtin_hcu_raw_buffer_load_lds(
auto *lds_ptr = (__attribute__((address_space(3))) int *)( read_buffer,
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float)); lds + lane_id * 4,
__builtin_hcu_raw_buffer_load_lds( 16,
read_buffer, (block_offset + thread_offset) << 1, /* v_offset */
lds_ptr, 0, /* s_offset */
16, 0, /* immediate offset, instruction offset */
(block_offset + thread_offset) << 1, /* v_offset */ 0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
0, /* s_offset */ );
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else #else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1)); *(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif #endif
// 从 LDS 转置后, 64 个线程写 4 行, 每次写 128 个 Half, 对应 fetch * 4 + [0,3] 的 seqlen // 从 LDS 转置后, 64 个线程写 4 行, 每次写 128 个 Half, 对应 fetch * 4 + [0,3] 的 seqlen
vec2_fp32 data0, data1; vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64); vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 0) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[0]; write_ptr[(min(actual_seqlen - 1, fetch * 4 + 0) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[0];
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 1) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[1]; write_ptr[(min(actual_seqlen - 1, fetch * 4 + 1) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data0[1];
write_ptr[(min(actual_seqlen - 1, fetch * 4 + 2) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data1[0]; write_ptr[(min(actual_seqlen - 1, fetch * 4 + 2) * num_heads * kHeadDim + (lane_id << 1)) >> 1] = data1[0];
...@@ -370,4 +363,4 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<256, 1, 32>( ...@@ -370,4 +363,4 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<256, 1, 32>(
#endif #endif
\ No newline at end of file
#ifdef BUILD_FA_PERMUTE #ifdef BUILD_FA_PERMUTE
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include "../../include/intrinsic.h"
#include "../flash_fwd_permute_hdim128.h" #include "../flash_fwd_permute_hdim128.h"
...@@ -118,29 +117,23 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 0>( ...@@ -118,29 +117,23 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 0>(
// 接下来, 这个 block 要读取 4x128 的内容, 15 个线程读取一行 128 个 half(这里写死了 head_dim = 128), 每个线程读取 8 个 half // 接下来, 这个 block 要读取 4x128 的内容, 15 个线程读取一行 128 个 half(这里写死了 head_dim = 128), 每个线程读取 8 个 half
int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8; int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下 // block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
{ __builtin_hcu_raw_buffer_load_lds(
auto *lds_ptr = (__attribute__((address_space(3))) int *)( read_buffer,
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float)); lds + lane_id * 4,
__builtin_hcu_raw_buffer_load_lds( 16,
read_buffer, (block_offset + thread_offset) << 1, /* v_offset */
lds_ptr, 0, /* s_offset */
16, 0, /* immediate offset, instruction offset */
(block_offset + thread_offset) << 1, /* v_offset */ 0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
0, /* s_offset */ );
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else #else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1)); *(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif #endif
// 写到 lds 不需要同步, 因为只有一个 wave // 写到 lds 不需要同步, 因为只有一个 wave
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次 // 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
vec2_fp32 data0, data1; vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64); vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
write_ptr[(fetch * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0]; write_ptr[(fetch * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0];
write_ptr[(fetch * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1]; write_ptr[(fetch * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1];
write_ptr[(fetch * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0]; write_ptr[(fetch * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0];
...@@ -267,29 +260,23 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 32>( ...@@ -267,29 +260,23 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 32>(
// 接下来, 这个 block 要读取 4x128 的内容, 15 个线程读取一行 128 个 half(这里写死了 head_dim = 128), 每个线程读取 8 个 half // 接下来, 这个 block 要读取 4x128 的内容, 15 个线程读取一行 128 个 half(这里写死了 head_dim = 128), 每个线程读取 8 个 half
int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8; int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下 // block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
{ __builtin_hcu_raw_buffer_load_lds(
auto *lds_ptr = (__attribute__((address_space(3))) int *)( read_buffer,
reinterpret_cast<size_t>(lds) + static_cast<size_t>(lane_id * 4) * sizeof(float)); lds + lane_id * 4,
__builtin_hcu_raw_buffer_load_lds( 16,
read_buffer, (block_offset + thread_offset) << 1, /* v_offset */
lds_ptr, 0, /* s_offset */
16, 0, /* immediate offset, instruction offset */
(block_offset + thread_offset) << 1, /* v_offset */ 0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
0, /* s_offset */ );
0, /* immediate offset, instruction offset */
0 /* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else #else
*(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1)); *(vec4_fp32*)(lds + lane_id * 4) = *(vec4_fp32*)(read_ptr + ((block_offset + thread_offset) >> 1));
#endif #endif
// 写到 lds 不需要同步, 因为只有一个 wave // 写到 lds 不需要同步, 因为只有一个 wave
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次 // 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
vec2_fp32 data0, data1; vec2_fp32 data0 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id, data0, 0, 64); vec2_fp32 data1 = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + lane_id + 128, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, lane_id + 128, data1, 0, 64);
asm volatile("s_waitcnt lgkmcnt(0)\n");
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0]; write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 0 * cur_seqlen_q * head_dim) >> 1] = data0[0];
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1]; write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 1 * cur_seqlen_q * head_dim) >> 1] = data0[1];
write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0]; write_ptr[(seqlen_limit * head_dim + (lane_id << 1) + 2 * cur_seqlen_q * head_dim) >> 1] = data1[0];
...@@ -352,7 +339,7 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128, ...@@ -352,7 +339,7 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
// 接下来, 这个 block 要读取 4x128 的内容, 15 个线程读取一行 128 个 half(这里写死了 head_dim = 128), 每个线程读取 8 个 half // 接下来, 这个 block 要读取 4x128 的内容, 15 个线程读取一行 128 个 half(这里写死了 head_dim = 128), 每个线程读取 8 个 half
int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8; int32_t thread_offset = lane_id_row * 128 + lane_id_col * 8;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下 // block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
int m0_offset = reinterpret_cast<size_t>(lds) + (fetch * 256 << 2); int m0_offset = reinterpret_cast<size_t>(lds) + (fetch * 256 << 2);
int offset_v = (block_offset + thread_offset) << 1; int offset_v = (block_offset + thread_offset) << 1;
asm volatile( asm volatile(
...@@ -377,15 +364,14 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128, ...@@ -377,15 +364,14 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
// 把所有的 buffer_load 指令下发之后, 再从 lds 开始读取 // 把所有的 buffer_load 指令下发之后, 再从 lds 开始读取
#pragma unroll #pragma unroll
for (int32_t fetch = 0; fetch < SEQLEN_PER_BLOCK; ++fetch) { for (int32_t fetch = 0; fetch < SEQLEN_PER_BLOCK; ++fetch) {
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(%0)\n" :: "B"(SEQLEN_PER_BLOCK - fetch - 1)); asm volatile("s_waitcnt vmcnt(%0)\n" :: "B"(SEQLEN_PER_BLOCK - fetch - 1));
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
#endif #endif
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次 // 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
inlineasm_fa_ds_read2_b32(lds, fetch * 256 + lane_id, registers_buffer[fetch * 2], 0, 64); registers_buffer[fetch * 2] = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + fetch * 256 + lane_id, 0, 64, false);
inlineasm_fa_ds_read2_b32(lds, fetch * 256 + lane_id + 128, registers_buffer[fetch * 2 + 1], 0, 64); registers_buffer[fetch * 2 + 1] = __builtin_hcu_ds_read2_f32((__attribute__((address_space(3))) float*)lds + fetch * 256 + lane_id + 128, 0, 64, false);
asm volatile("s_waitcnt lgkmcnt(0)\n");
} }
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -394,7 +380,7 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128, ...@@ -394,7 +380,7 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
for (int32_t fetch = 0; fetch < SEQLEN_PER_BLOCK; ++fetch) { for (int32_t fetch = 0; fetch < SEQLEN_PER_BLOCK; ++fetch) {
// 限制边界 // 限制边界
int32_t seqlen_limit = min(actual_seqlen - 1, fetch); int32_t seqlen_limit = min(actual_seqlen - 1, fetch);
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
// 计算固定的偏移, 字节数目 // 计算固定的偏移, 字节数目
int32_t v_addr = (seqlen_limit * head_dim << 1) + (lane_id << 2); int32_t v_addr = (seqlen_limit * head_dim << 1) + (lane_id << 2);
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次 // 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
...@@ -688,4 +674,4 @@ template<> ...@@ -688,4 +674,4 @@ template<>
__global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<256, 4, 32>( __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<256, 4, 32>(
void* output, void* query, void* split_sizes, int64_t head_stride, int32_t num_heads, int real_headdim) {} void* output, void* query, void* split_sizes, int64_t head_stride, int32_t num_heads, int real_headdim) {}
#endif // end of BUILD_FA_PERMUTE #endif // end of BUILD_FA_PERMUTE
\ No newline at end of file
...@@ -8,6 +8,8 @@ if torch.cuda.is_available(): ...@@ -8,6 +8,8 @@ if torch.cuda.is_available():
flash_attn_qkvpacked_func, flash_attn_qkvpacked_func,
flash_attn_varlen_func, flash_attn_varlen_func,
hg_flash_attn_varlen_func, hg_flash_attn_varlen_func,
flash_mla_with_kvcache,
get_mla_metadata,
vllm_flash_attn_varlen_func, vllm_flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func, flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
from typing import Optional, Union from typing import Optional, Union
from typing import List, Tuple from typing import List, Tuple
import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -19,13 +18,19 @@ from flash_attn.utils.sparse_utils import hyperparameter_check, get_block_map_me ...@@ -19,13 +18,19 @@ from flash_attn.utils.sparse_utils import hyperparameter_check, get_block_map_me
DEFAULT_FA_VERSION = 2 DEFAULT_FA_VERSION = 2
try:
torch._C._dispatch_find_schema_or_throw("flash_attn2_c_op::varlen_fwd", "")
_has_flash_attn2_c_varlen_fwd = True
except RuntimeError:
_has_flash_attn2_c_varlen_fwd = False
def maybe_contiguous(x): def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def round_multiple(x, m): def round_multiple(x, m):
return (x + m - 1) // m * m return (x + m - 1) // m * m
def _get_block_size_n(device, head_dim, is_dropout, is_causal): def _get_block_size_n(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel # This should match the block sizes in the CUDA kernel
...@@ -59,7 +64,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal): ...@@ -59,7 +64,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
elif head_dim <= 512: elif head_dim <= 512:
return 64 return 64
if torch.__version__ >= "2.4.0": if torch.__version__ >= "2.4.0" and _has_flash_attn2_c_varlen_fwd:
_torch_custom_op_wrapper = torch.library.custom_op _torch_custom_op_wrapper = torch.library.custom_op
_torch_register_fake_wrapper = torch.library.register_fake _torch_register_fake_wrapper = torch.library.register_fake
else: else:
...@@ -199,7 +204,11 @@ def varlen_fwd_fake( ...@@ -199,7 +204,11 @@ def varlen_fwd_fake(
return out, softmax_lse, p, rng_state return out, softmax_lse, p, rng_state
_wrapped_flash_attn_varlen_forward = torch.ops.flash_attn2_c_op.varlen_fwd _wrapped_flash_attn_varlen_forward = (
torch.ops.flash_attn2_c_op.varlen_fwd
if _has_flash_attn2_c_varlen_fwd
else _flash_attn_varlen_forward
)
def _flash_attn_backward( def _flash_attn_backward(
...@@ -596,7 +605,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): ...@@ -596,7 +605,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
softcap, softcap,
alibi_slopes, alibi_slopes,
deterministic, deterministic,
return_softmax, return_softmax,
bhsd = False bhsd = False
): ):
if softmax_scale is None: if softmax_scale is None:
...@@ -611,7 +620,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): ...@@ -611,7 +620,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
window_size=window_size, window_size=window_size,
softcap=softcap, softcap=softcap,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
bhsd = bhsd bhsd = bhsd
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
...@@ -1922,7 +1931,7 @@ def vllm_flash_attn_varlen_func( ...@@ -1922,7 +1931,7 @@ def vllm_flash_attn_varlen_func(
# Version selector # Version selector
fa_version: int = DEFAULT_FA_VERSION, fa_version: int = DEFAULT_FA_VERSION,
s_aux=None, s_aux=None,
): ):
""" """
仅用于vllm prefix cache 仅用于vllm prefix cache
dropout_p should be set to 0.0 during evaluation dropout_p should be set to 0.0 during evaluation
...@@ -1994,7 +2003,7 @@ def vllm_flash_attn_varlen_func( ...@@ -1994,7 +2003,7 @@ def vllm_flash_attn_varlen_func(
else: else:
assert len(window_size) == 2 assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1]) real_window_size = (window_size[0], window_size[1])
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
...@@ -2005,7 +2014,7 @@ def vllm_flash_attn_varlen_func( ...@@ -2005,7 +2014,7 @@ def vllm_flash_attn_varlen_func(
bs = cu_seqlens_q.shape[0] - 1 bs = cu_seqlens_q.shape[0] - 1
total_q = q.shape[0] total_q = q.shape[0]
# max_seqlen_q*bs==total_q and max_seqlen_q<=4 means mtp # max_seqlen_q*bs==total_q and max_seqlen_q<=4 means mtp
# if mtp, k head must be 1. # if mtp, k head must be 1.
# todo : support k head >1 # todo : support k head >1
is_mtp = (max_seqlen_q*bs==total_q and max_seqlen_q>1 and max_seqlen_q<5) is_mtp = (max_seqlen_q*bs==total_q and max_seqlen_q>1 and max_seqlen_q<5)
if (max_seqlen_q==1 or is_mtp ) and real_window_size[0]==-1: if (max_seqlen_q==1 or is_mtp ) and real_window_size[0]==-1:
...@@ -2015,9 +2024,9 @@ def vllm_flash_attn_varlen_func( ...@@ -2015,9 +2024,9 @@ def vllm_flash_attn_varlen_func(
else : else :
out = torch.empty_like(q) out = torch.empty_like(q)
flash_attn_cuda.paged_attention(out,q.reshape(bs,max_seqlen_q,q.shape[1],q.shape[-1]),k,v,softmax_scale,block_table, flash_attn_cuda.paged_attention(out,q.reshape(bs,max_seqlen_q,q.shape[1],q.shape[-1]),k,v,softmax_scale,block_table,
seqused_k,alibi_slopes,kv_cache_dtype,q_descale,k_descale,v_descale,max_seqlen_k,s_aux) seqused_k,alibi_slopes,kv_cache_dtype,q_descale,k_descale,v_descale,max_seqlen_k,s_aux)
return out return out
is_938 = "gfx938" in torch.cuda.get_device_properties("cuda").gcnArchName is_938 = ("gfx938" in torch.cuda.get_device_properties("cuda").gcnArchName or "gfx92a" in torch.cuda.get_device_properties("cuda").gcnArchName)
if (not is_938) and k.dtype == torch.float8_e5m2 and v.dtype == torch.float8_e5m2: if (not is_938) and k.dtype == torch.float8_e5m2 and v.dtype == torch.float8_e5m2:
assert q.dtype != torch.float8_e5m2 , "UnSupport q.dtype:fp8" assert q.dtype != torch.float8_e5m2 , "UnSupport q.dtype:fp8"
q_descale = None q_descale = None
...@@ -2048,7 +2057,7 @@ def vllm_flash_attn_varlen_func( ...@@ -2048,7 +2057,7 @@ def vllm_flash_attn_varlen_func(
None, None,
s_aux, s_aux,
) )
else: else:
if(k.dtype == torch.float8_e4m3fn or k.dtype == torch.float8_e5m2) and q.dtype != k.dtype: if(k.dtype == torch.float8_e4m3fn or k.dtype == torch.float8_e5m2) and q.dtype != k.dtype:
if q_descale is not None: if q_descale is not None:
q=q/q_descale q=q/q_descale
...@@ -2059,7 +2068,7 @@ def vllm_flash_attn_varlen_func( ...@@ -2059,7 +2068,7 @@ def vllm_flash_attn_varlen_func(
v, v,
out, out,
cu_seqlens_q, cu_seqlens_q,
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# still wants it so we pass all zeros # still wants it so we pass all zeros
dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
seqused_k, seqused_k,
...@@ -2092,7 +2101,7 @@ def vllm_flash_attn_varlen_func( ...@@ -2092,7 +2101,7 @@ def vllm_flash_attn_varlen_func(
v, v,
out, out,
cu_seqlens_q, cu_seqlens_q,
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# still wants it so we pass all zeros # still wants it so we pass all zeros
dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
seqused_k, seqused_k,
...@@ -2334,6 +2343,7 @@ def flash_attn_with_kvcache( ...@@ -2334,6 +2343,7 @@ def flash_attn_with_kvcache(
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
s_aux = maybe_contiguous(s_aux)
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
if cache_seqlens is not None and isinstance(cache_seqlens, int): if cache_seqlens is not None and isinstance(cache_seqlens, int):
...@@ -2646,7 +2656,7 @@ def sparse_attn_varlen_func( ...@@ -2646,7 +2656,7 @@ def sparse_attn_varlen_func(
block_count and block_offset for slash sparsity patterns, and block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns. column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490. For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments: Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...@@ -2682,7 +2692,7 @@ def sparse_attn_varlen_func( ...@@ -2682,7 +2692,7 @@ def sparse_attn_varlen_func(
""" """
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = flash_attn_cuda.varlen_fwd_sparse( out, softmax_lse = flash_attn_cuda.varlen_fwd_sparse(
q, q,
...@@ -2723,7 +2733,7 @@ def varlen_fwd_unified( ...@@ -2723,7 +2733,7 @@ def varlen_fwd_unified(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
softcap=0.0, softcap=0.0,
window_size=(-1, -1), window_size=(-1, -1),
alibi_slopes=None, alibi_slopes=None,
use_alibi_sqrt=False, use_alibi_sqrt=False,
qq_bias=None, qq_bias=None,
...@@ -2732,15 +2742,125 @@ def varlen_fwd_unified( ...@@ -2732,15 +2742,125 @@ def varlen_fwd_unified(
*, *,
out=None, out=None,
return_softmax_lse=False, return_softmax_lse=False,
q_descale=None,
k_descale=None,
v_descale=None,
): ):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
window_size_left, window_size_right = window_size window_size_left, window_size_right = window_size
if k.dtype.is_floating_point:
k_dtype_bits = torch.finfo(k.dtype).bits
else:
k_dtype_bits = torch.iinfo(k.dtype).bits
is_mtp = (max_seqlen_q * seqused_k.size(0) == q.shape[0] and 1 < max_seqlen_q < 16)
if max_seqlen_q >= 16:
fp8_dtypes = [torch.float8_e4m3fn]
if hasattr(torch, "float8_e4m3fnuz"):
fp8_dtypes.append(torch.float8_e4m3fnuz)
if k_dtype_bits == 16 and q.shape[-1] == 256 and v.shape[-1] == 256:
out, softmax_lse = flash_attn_cuda.varlen_fwd_unified(
q,
k,
v,
out,
cu_seqlens_q,
max_seqlen_q,
seqused_k,
max_seqlen_k,
block_table,
softmax_scale,
softcap,
None, # q_descale
None, # k_descale
None, # v_descale
None, # output_scale
causal,
window_size_left,
window_size_right,
alibi_slopes,
use_alibi_sqrt,
qq_bias,
s_aux,
mm_prefix_range,
)
return (out, softmax_lse) if return_softmax_lse else out
bshd_prefill = _require_hg_varlen_symbol("hg_prefix_prefill_varlen_fwd")
fa_output, *rest_extend = bshd_prefill(
q,
k,
v,
out, # out_
cu_seqlens_q,
None, # cu_seqlens_k
seqused_k,
alibi_slopes,
block_table,
max_seqlen_q,
max_seqlen_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
causal,
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
1,
None if (k_dtype_bits == 16) else q_descale,
None if (k_dtype_bits == 16) else k_descale,
None if (k_dtype_bits == 16) else v_descale,
s_aux,
True,
)
return (fa_output, rest_extend[0]) if return_softmax_lse else fa_output
bs = seqused_k.size(0)
total_q = q.shape[0]
if max_seqlen_q == 1 or is_mtp:
if k_dtype_bits == 16 and s_aux is not None:
raise RuntimeError(
"b16 prefix decode with attention sink is not supported by unified attention yet"
)
assert not use_alibi_sqrt and qq_bias is None and mm_prefix_range is None, \
f"Arguments not supported in hg_bshd_decode"
bshd_pa_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")
result = bshd_pa_decode(
q,
k,
v,
out,
cu_seqlens_q,
None,
seqused_k,
alibi_slopes,
block_table,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale,
False,
causal,
window_size[0],
window_size[1],
softcap,
return_softmax_lse,
1,
None if (k_dtype_bits == 16) else q_descale,
None if (k_dtype_bits == 16) else k_descale,
None if (k_dtype_bits == 16) else v_descale,
s_aux,
True,
)
fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output
out, softmax_lse = flash_attn_cuda.varlen_fwd_unified( out, softmax_lse = flash_attn_cuda.varlen_fwd_unified(
q, q,
k, k,
...@@ -3830,8 +3950,8 @@ def get_block_map_fast(q, k, topk_ratio, BLKQ=128, BLKK=64): ...@@ -3830,8 +3950,8 @@ def get_block_map_fast(q, k, topk_ratio, BLKQ=128, BLKK=64):
sparse_map = torch.zeros_like(pooled_score, dtype=torch.int8) sparse_map = torch.zeros_like(pooled_score, dtype=torch.int8)
sparse_map.scatter_(-1, lut, 1) sparse_map.scatter_(-1, lut, 1)
return sparse_map, lut, topk return sparse_map, lut, topk
class SparseLinearAttention(nn.Module): class SparseLinearAttention(nn.Module):
def __init__(self, head_dim, topk, feature_map='softmax', use_bf16=True, use_fp8=False, tie_feature_map_qk=True): def __init__(self, head_dim, topk, feature_map='softmax', use_bf16=True, use_fp8=False, tie_feature_map_qk=True):
R''' R'''
...@@ -3877,7 +3997,7 @@ class SparseLinearAttention(nn.Module): ...@@ -3877,7 +3997,7 @@ class SparseLinearAttention(nn.Module):
with torch.no_grad(): with torch.no_grad():
nn.init.zeros_(self.proj_l.weight) nn.init.zeros_(self.proj_l.weight)
nn.init.zeros_(self.proj_l.bias) nn.init.zeros_(self.proj_l.bias)
def forward(self, q, k, v, return_sparsity=False): def forward(self, q, k, v, return_sparsity=False):
R''' R'''
Args: Args:
...@@ -3886,18 +4006,18 @@ class SparseLinearAttention(nn.Module): ...@@ -3886,18 +4006,18 @@ class SparseLinearAttention(nn.Module):
v: values of shape (B, L, H, D). v: values of shape (B, L, H, D).
return_sparsity: whether to return the actual sparsity. return_sparsity: whether to return the actual sparsity.
''' '''
B, seqlen_q, H, headdim = q.shape B, seqlen_q, H, headdim = q.shape
if headdim == 64: if headdim == 64:
block_m = 64 if seqlen_q <= 2048 else 128 block_m = 64 if seqlen_q <= 2048 else 128
elif headdim == 128: elif headdim == 128:
block_m = 64 if seqlen_q <= 2048 else 128 block_m = 64 if seqlen_q <= 2048 else 128
block_k = 64 block_k = 64
if headdim == 64: if headdim == 64:
sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k) sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
else: else:
sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k) sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
q = q.to(self.dtype) q = q.to(self.dtype)
k = k.to(self.dtype) k = k.to(self.dtype)
v = v.to(self.dtype) v = v.to(self.dtype)
...@@ -3981,7 +4101,7 @@ def sparse_attn_with_sla( ...@@ -3981,7 +4101,7 @@ def sparse_attn_with_sla(
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor). normalization factor).
""" """
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
dtype = torch.bfloat16 if use_bf16 else torch.float16 dtype = torch.bfloat16 if use_bf16 else torch.float16
...@@ -3994,12 +4114,12 @@ def sparse_attn_with_sla( ...@@ -3994,12 +4114,12 @@ def sparse_attn_with_sla(
block_m = 64 if seqlen_q <= 2048 else 128 block_m = 64 if seqlen_q <= 2048 else 128
elif headdim == 128: elif headdim == 128:
block_m = 64 if seqlen_q <= 2048 else 128 block_m = 64 if seqlen_q <= 2048 else 128
block_k = 64 block_k = 64
if headdim == 64: if headdim == 64:
sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=topk, BLKQ=block_m, BLKK=block_k) sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=topk, BLKQ=block_m, BLKK=block_k)
else: else:
sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=topk, BLKQ=block_m, BLKK=block_k) sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=topk, BLKQ=block_m, BLKK=block_k)
q = q.to(dtype) q = q.to(dtype)
k = k.to(dtype) k = k.to(dtype)
v = v.to(dtype) v = v.to(dtype)
...@@ -4045,15 +4165,6 @@ def _require_hg_varlen_symbol(name: str): ...@@ -4045,15 +4165,6 @@ def _require_hg_varlen_symbol(name: str):
return symbol return symbol
def _apply_hg_kvcache_safe_env() -> None:
# DTK/gfx938 PA launch is sensitive to these knobs. Keep the old HG-safe
# defaults unless the caller explicitly asks for the raw kernel selection.
if os.environ.get("HG_KVCACHE_RAW_KERNEL") == "1":
return
os.environ.setdefault("PA_NO_MLS", "1")
os.environ.setdefault("PA_USE_TILE32X32", "1")
def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None: def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None:
if k_cache.dim() != 4 or v_cache.dim() != 4: if k_cache.dim() != 4 or v_cache.dim() != 4:
raise ValueError("HG paged KV cache expects k and v to both be 4D tensors") raise ValueError("HG paged KV cache expects k and v to both be 4D tensors")
...@@ -4066,53 +4177,56 @@ def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None: ...@@ -4066,53 +4177,56 @@ def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None:
"v=[num_blocks, page_block_size, num_heads_k, d_v]" "v=[num_blocks, page_block_size, num_heads_k, d_v]"
) )
def get_mla_metadata(
def _normalize_hg_paged_q_scales(q_scale, batch_size, num_heads_q, num_heads_k): cache_seqlens: torch.Tensor,
if q_scale is None: num_heads_per_head_k: int,
raise ValueError("q_descale must be provided for HG int8 paged-kvcache path") num_heads_k: int,
q_scale = maybe_contiguous(q_scale) is_fp8_kvcache: bool = False,
if q_scale.dim() == 1: ):
if q_scale.numel() == batch_size * num_heads_q: return None, None
q_scale = q_scale.view(batch_size, num_heads_q)
elif q_scale.numel() == batch_size * num_heads_k:
q_scale = q_scale.view(batch_size, num_heads_k)
if q_scale.dim() != 2 or q_scale.shape[0] != batch_size:
raise ValueError(
"q_descale must have shape [batch_size, num_heads_q] "
"or [batch_size, num_heads_k] for HG int8 paged-kvcache path"
)
if q_scale.shape[1] == num_heads_q:
return q_scale.contiguous()
if q_scale.shape[1] == num_heads_k and num_heads_q % num_heads_k == 0:
return q_scale.repeat_interleave(num_heads_q // num_heads_k, dim=1).contiguous()
raise ValueError(
"q_descale must have shape [batch_size, num_heads_q] "
"or [batch_size, num_heads_k] for HG int8 paged-kvcache path"
)
def _expand_hg_paged_kv_scales(scale, block_table, page_block_size, num_heads_k, name): def flash_mla_with_kvcache(
if scale is None: q: torch.Tensor,
raise ValueError(f"{name} must be provided for HG int8 paged-kvcache path") k_cache: torch.Tensor,
scale = maybe_contiguous(scale) block_table: torch.Tensor,
batch_size = block_table.shape[0] cache_seqlens: torch.Tensor,
if scale.dim() == 1 and scale.numel() == batch_size * num_heads_k: head_dim_v: int,
scale = scale.view(batch_size, num_heads_k) tile_scheduler_metadata: Optional[torch.Tensor],
if scale.dim() != 2 or scale.shape != (batch_size, num_heads_k): num_splits: Optional[torch.Tensor],
raise ValueError( softmax_scale: Optional[float] = None,
f"{name} must have shape [batch_size, num_heads_k] for HG int8 paged-kvcache path" causal: bool = False,
use_cuda_graph: bool = True,
out: Optional[torch.Tensor] = None,
):
if k_cache.dtype not in (torch.float16, torch.bfloat16):
raise NotImplementedError(
"HG MLA dispatch in the main repository supports fp16/bf16 only; "
"fp8/int8 MLA is not supported."
) )
expanded = torch.empty( if softmax_scale is None:
(int(block_table.max().item()) + 1, page_block_size, num_heads_k), softmax_scale = q.shape[-1] ** (-0.5)
device=scale.device,
dtype=scale.dtype, hg_mla = _require_hg_varlen_symbol("hg_fwd_kvcache_mla")
max_seqlen_k = 1 if use_cuda_graph else cache_seqlens.max().item()
result = hg_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
out,
max_seqlen_k,
) )
for batch_idx in range(batch_size): if len(result) < 2:
block_ids = block_table[batch_idx].to(dtype=torch.long) raise RuntimeError("hg_fwd_kvcache_mla did not return softmax_lse")
expanded[block_ids] = scale[batch_idx].view(1, 1, num_heads_k).expand( return result[0], result[1]
block_ids.numel(), page_block_size, num_heads_k
)
return expanded.contiguous()
def hg_flash_attn_varlen_func( def hg_flash_attn_varlen_func(
q, q,
...@@ -4247,8 +4361,6 @@ def hg_flash_attn_varlen_func( ...@@ -4247,8 +4361,6 @@ def hg_flash_attn_varlen_func(
unsupported.append("num_splits") unsupported.append("num_splits")
if fa_version != 2: if fa_version != 2:
unsupported.append("fa_version") unsupported.append("fa_version")
if s_aux is not None:
unsupported.append("s_aux")
if custom_mask is not None: if custom_mask is not None:
unsupported.append("custom_mask") unsupported.append("custom_mask")
if unsupported: if unsupported:
...@@ -4266,6 +4378,7 @@ def hg_flash_attn_varlen_func( ...@@ -4266,6 +4378,7 @@ def hg_flash_attn_varlen_func(
raise ValueError("cu_seqlens_q must be provided") raise ValueError("cu_seqlens_q must be provided")
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
s_aux = maybe_contiguous(s_aux)
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
...@@ -4333,10 +4446,6 @@ def hg_flash_attn_varlen_func( ...@@ -4333,10 +4446,6 @@ def hg_flash_attn_varlen_func(
raise ValueError("cu_seqlens_k and seqused_k cannot be provided at the same time") raise ValueError("cu_seqlens_k and seqused_k cannot be provided at the same time")
if block_table is None: if block_table is None:
raise ValueError("block_table must be provided when seqused_k is used") raise ValueError("block_table must be provided when seqused_k is used")
if return_attn_probs:
raise NotImplementedError(
"return_attn_probs is not supported for HG prefix/paged compatibility paths"
)
if dropout_p != 0.0: if dropout_p != 0.0:
raise NotImplementedError("dropout_p must be 0.0 for HG prefix/paged compatibility paths") raise NotImplementedError("dropout_p must be 0.0 for HG prefix/paged compatibility paths")
...@@ -4346,6 +4455,33 @@ def hg_flash_attn_varlen_func( ...@@ -4346,6 +4455,33 @@ def hg_flash_attn_varlen_func(
k_dtype_bits = torch.iinfo(k.dtype).bits k_dtype_bits = torch.iinfo(k.dtype).bits
if max_seqlen_q > 16 or (k_dtype_bits == 8 and max_seqlen_q > 1): if max_seqlen_q > 16 or (k_dtype_bits == 8 and max_seqlen_q > 1):
if k_dtype_bits == 16 and q.shape[-1] == 256 and v.shape[-1] == 256:
out, softmax_lse = flash_attn_cuda.varlen_fwd_unified(
q,
k,
v,
out,
cu_seqlens_q,
max_seqlen_q,
seqused_k,
max_seqlen_k,
block_table,
softmax_scale,
softcap,
None, # q_descale
None, # k_descale
None, # v_descale
None, # output_scale
causal,
window_size[0],
window_size[1],
alibi_slopes,
False, # use_alibi_sqrt
None, # qq_bias
s_aux,
None, # mm_prefix_range
)
return (out, softmax_lse) if wants_aux else out
prefix_prefill = _require_hg_varlen_symbol("hg_prefix_prefill_varlen_fwd") prefix_prefill = _require_hg_varlen_symbol("hg_prefix_prefill_varlen_fwd")
result = prefix_prefill( result = prefix_prefill(
q, q,
...@@ -4366,15 +4502,16 @@ def hg_flash_attn_varlen_func( ...@@ -4366,15 +4502,16 @@ def hg_flash_attn_varlen_func(
window_size[0], window_size[0],
window_size[1], window_size[1],
softcap, softcap,
return_softmax_lse, return_softmax_lse or return_attn_probs,
1, 1,
None if k_dtype_bits == 16 else q_descale, None if k_dtype_bits == 16 else q_descale,
None if k_dtype_bits == 16 else k_descale, None if k_dtype_bits == 16 else k_descale,
None if k_dtype_bits == 16 else v_descale, None if k_dtype_bits == 16 else v_descale,
s_aux,
is_bf16_output, is_bf16_output,
) )
fa_output = result[0] fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output return (fa_output, result[1]) if (return_softmax_lse or return_attn_probs) else fa_output
if k_dtype_bits == 16: if k_dtype_bits == 16:
prefix_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd") prefix_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")
...@@ -4397,13 +4534,18 @@ def hg_flash_attn_varlen_func( ...@@ -4397,13 +4534,18 @@ def hg_flash_attn_varlen_func(
window_size[0], window_size[0],
window_size[1], window_size[1],
softcap, softcap,
return_softmax_lse, return_softmax_lse or return_attn_probs,
1, 1,
None,
None,
None,
s_aux,
is_bf16_output,
) )
fa_output = result[0] fa_output = result[0]
return (fa_output, result[1]) if return_softmax_lse else fa_output return (fa_output, result[1]) if (return_softmax_lse or return_attn_probs) else fa_output
if return_softmax_lse: if return_softmax_lse or return_attn_probs:
raise NotImplementedError( raise NotImplementedError(
"return_softmax_lse is not supported for the HG paged-kvcache compatibility path" "return_softmax_lse is not supported for the HG paged-kvcache compatibility path"
) )
...@@ -4411,28 +4553,6 @@ def hg_flash_attn_varlen_func( ...@@ -4411,28 +4553,6 @@ def hg_flash_attn_varlen_func(
_validate_hg_paged_kv_contract(k, v) _validate_hg_paged_kv_contract(k, v)
if k.shape[1] != 128: if k.shape[1] != 128:
raise NotImplementedError("HG paged-kvcache path currently requires page_block_size == 128") raise NotImplementedError("HG paged-kvcache path currently requires page_block_size == 128")
_apply_hg_kvcache_safe_env()
q_descale = _normalize_hg_paged_q_scales(
q_descale,
batch_size=block_table.shape[0],
num_heads_q=q.shape[1],
num_heads_k=k.shape[2],
)
k_descale = _expand_hg_paged_kv_scales(
k_descale,
block_table=block_table,
page_block_size=k.shape[1],
num_heads_k=k.shape[2],
name="k_descale",
)
v_descale = _expand_hg_paged_kv_scales(
v_descale,
block_table=block_table,
page_block_size=v.shape[1],
num_heads_k=v.shape[2],
name="v_descale",
)
hg_kvcache = _require_hg_varlen_symbol("hg_fwd_kvcache_bshd") hg_kvcache = _require_hg_varlen_symbol("hg_fwd_kvcache_bshd")
result = hg_kvcache( result = hg_kvcache(
q.unsqueeze(1), q.unsqueeze(1),
...@@ -4442,7 +4562,7 @@ def hg_flash_attn_varlen_func( ...@@ -4442,7 +4562,7 @@ def hg_flash_attn_varlen_func(
None, None,
None, None,
seqused_k, seqused_k,
max_seqlen_k if max_seqlen_k > 0 else int(seqused_k.max().item()), 1,
None, None,
None, None,
None, None,
...@@ -4456,12 +4576,12 @@ def hg_flash_attn_varlen_func( ...@@ -4456,12 +4576,12 @@ def hg_flash_attn_varlen_func(
window_size[1], window_size[1],
softcap, softcap,
False, False,
num_splits, -1,
None, None,
None, None,
q_descale, None if k_dtype_bits == 16 else q_descale,
k_descale, None if k_dtype_bits == 16 else k_descale,
v_descale, None if k_dtype_bits == 16 else v_descale,
is_bf16_output, is_bf16_output,
) )
return result[0].squeeze(1) return result[0].squeeze(1)
...@@ -52,6 +52,20 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE ...@@ -52,6 +52,20 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
def cutlass_include_dirs():
candidates = [
Path(this_dir) / "csrc" / "cutlass" / "include",
]
cutlass_home = os.getenv("CUTLASS_HOME")
if cutlass_home:
candidates.append(Path(cutlass_home) / "include")
candidates.extend([
Path("/workspace/cutlass_3.2.1/include"),
Path("/public/home/huangly/数据采集/cutlass_3.2.1/include"),
])
return [str(path) for path in candidates if path.exists()]
def get_platform(): def get_platform():
""" """
Returns the platform name as used in wheel filenames. Returns the platform name as used in wheel filenames.
...@@ -110,8 +124,18 @@ _HG_EXPLICIT_SOURCES_BY_MODE = { ...@@ -110,8 +124,18 @@ _HG_EXPLICIT_SOURCES_BY_MODE = {
"src/target/flash_fwd_hdim128_fp16.cpp", "src/target/flash_fwd_hdim128_fp16.cpp",
"src/target/flash_fwd_hdim128_padding_mask_bf16.cpp", "src/target/flash_fwd_hdim128_padding_mask_bf16.cpp",
"src/target/flash_fwd_hdim128_padding_mask_fp16.cpp", "src/target/flash_fwd_hdim128_padding_mask_fp16.cpp",
"src/target/flash_fp8_fwd_hdim128_bf16.cpp",
"src/target/flash_fp8_fwd_hdim128_fp16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp",
"src/target/flash_fwd_hdim128_prefix_prefill_bf16.cpp", "src/target/flash_fwd_hdim128_prefix_prefill_bf16.cpp",
"src/target/flash_fwd_hdim128_prefix_prefill_fp16.cpp", "src/target/flash_fwd_hdim128_prefix_prefill_fp16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp",
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp",
"src/target/flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_bf16.cpp",
"src/target/flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_fp16.cpp",
"src/target/flash_fp8_fwd_hdim256_prefix_prefill_bf16.cpp",
"src/target/flash_fp8_fwd_hdim256_prefix_prefill_fp16.cpp",
"src/target/flash_fwd_hdim160_bf16.cpp", "src/target/flash_fwd_hdim160_bf16.cpp",
"src/target/flash_fwd_hdim160_fp16.cpp", "src/target/flash_fwd_hdim160_fp16.cpp",
"src/target/flash_fwd_hdim192_bf16.cpp", "src/target/flash_fwd_hdim192_bf16.cpp",
...@@ -262,13 +286,64 @@ def _ninja_shell_join(args) -> str: ...@@ -262,13 +286,64 @@ def _ninja_shell_join(args) -> str:
return " ".join(_ninja_escape(shlex.quote(str(x))) for x in args) return " ".join(_ninja_escape(shlex.quote(str(x))) for x in args)
def _resolve_hg_compiler() -> str:
candidates = [
os.environ.get("FLASH_ATTN_HG_COMPILER"),
"/opt/dtk/bin/aicc",
"aicc",
]
for compiler in candidates:
if compiler and shutil.which(compiler):
return compiler
requested = [c for c in candidates if c]
raise RuntimeError(
"error: no usable HG aicc compiler found from: "
+ ", ".join(repr(c) for c in requested)
+ ". Set FLASH_ATTN_HG_COMPILER to the DTK aicc path."
)
def _normalize_hg_gfx_archs(gfx_version: str):
archs = []
for item in str(gfx_version).replace(",", ";").split(";"):
item = item.strip()
if not item:
continue
archs.append(item if item.startswith("gfx") else f"gfx{item}")
return archs
def _hg_target_define_value(archs):
return ",".join(arch[3:] if arch.startswith("gfx") else arch for arch in archs)
def _hg_arch_device_flags(archs):
flags = []
for arch in archs:
if arch in ("gfx936", "gfx938"):
flags.extend([f"-Xarch_{arch}", "-mllvm=-support-768-vgprs=true"])
elif arch in ("gfx928", "gfx92a", "gfx946"):
flags.extend([f"-Xarch_{arch}", "-mllvm=-support-512-vgprs=true"])
flags.extend([f"-Xarch_{arch}", "-mllvm=-co-issue-vgpr-size=256"])
return flags
def _is_rocm_5_7() -> bool:
version_file = "/opt/rocm/.info/version"
try:
with open(version_file, "r", encoding="utf-8") as f:
return f.read(100).startswith("5.7.0")
except OSError:
return False
def compute_hg_build_descriptor( def compute_hg_build_descriptor(
src_dir, src_dir,
build_dir, build_dir,
mode="all", mode="all",
extra_options_raw="-DGFX_VERSION=938 -Wl,-Bsymbolic", extra_options_raw="-DGFX_VERSION=938;936 -Wl,-Bsymbolic",
): ):
"""Collect HG sources and flags for Ninja (no compile). Default: mode=all, gfx938.""" """Collect HG sources and flags for Ninja (no compile). Default: mode=all, gfx938/gfx936."""
import sysconfig as _sysconfig import sysconfig as _sysconfig
src_dir = os.path.abspath(str(src_dir)) src_dir = os.path.abspath(str(src_dir))
...@@ -279,7 +354,7 @@ def compute_hg_build_descriptor( ...@@ -279,7 +354,7 @@ def compute_hg_build_descriptor(
BUILD_FA_FWD = BUILD_FA_BWD = BUILD_FA_KVCACHE = False BUILD_FA_FWD = BUILD_FA_BWD = BUILD_FA_KVCACHE = False
BUILD_FA_PERMUTE = BUILD_FLASHMLA = False BUILD_FA_PERMUTE = BUILD_FLASHMLA = False
BUILD_C_INTERFACE = False BUILD_C_INTERFACE = False
BUILD_ASM = False BUILD_ASM = True
FA_DEBUG = True FA_DEBUG = True
FA_DEBUG_SUM_MAX = False FA_DEBUG_SUM_MAX = False
HEADDIM_128_ONLY = False HEADDIM_128_ONLY = False
...@@ -355,14 +430,11 @@ def compute_hg_build_descriptor( ...@@ -355,14 +430,11 @@ def compute_hg_build_descriptor(
EXTRA_HIP_FLAGS.append(_tok) EXTRA_HIP_FLAGS.append(_tok)
if GFX_VERSION is None: if GFX_VERSION is None:
GFX_VERSION = "938" GFX_VERSION = "938;936"
ROCM_PATH = os.environ.get("ROCM_PATH", os.environ.get("ROCM_HOME", "/opt/rocm")) ROCM_PATH = os.environ.get("ROCM_PATH", os.environ.get("ROCM_HOME", "/opt/rocm"))
HG_COMPILER = _resolve_hg_compiler()
if not shutil.which("hipcc"):
raise RuntimeError(
"error: hipcc not found in PATH. Please activate the DTK environment first."
)
if not os.path.isdir(os.path.join(ROCM_PATH, "include")): if not os.path.isdir(os.path.join(ROCM_PATH, "include")):
raise RuntimeError( raise RuntimeError(
f"error: {ROCM_PATH}/include not found. " f"error: {ROCM_PATH}/include not found. "
...@@ -400,7 +472,8 @@ def compute_hg_build_descriptor( ...@@ -400,7 +472,8 @@ def compute_hg_build_descriptor(
"-lc10", "-lc10",
] ]
_gfx_comma = GFX_VERSION.replace(";", ",") HG_ARCHS = _normalize_hg_gfx_archs(GFX_VERSION)
_gfx_comma = _hg_target_define_value(HG_ARCHS)
DEFINES = [ DEFINES = [
f"-DTARGET={_gfx_comma}", f"-DTARGET={_gfx_comma}",
"-D__HIP_PLATFORM_AMD__=1", "-D__HIP_PLATFORM_AMD__=1",
...@@ -444,8 +517,10 @@ def compute_hg_build_descriptor( ...@@ -444,8 +517,10 @@ def compute_hg_build_descriptor(
DEFINES.append("-DPA_PAGE_BLOCK_SIZE") DEFINES.append("-DPA_PAGE_BLOCK_SIZE")
if MLA_PAGE_BLOCK_SIZE: if MLA_PAGE_BLOCK_SIZE:
DEFINES.append("-DMLA_PAGE_BLOCK_SIZE") DEFINES.append("-DMLA_PAGE_BLOCK_SIZE")
if _is_rocm_5_7():
DEFINES.append("-DROCM_5_7")
OFFLOAD_FLAGS = [f"--offload-arch=gfx{_g}" for _g in GFX_VERSION.split(";") if _g] OFFLOAD_FLAGS = [f"--offload-arch={_g}" for _g in HG_ARCHS]
INCLUDE_FLAGS = [ INCLUDE_FLAGS = [
f"-I{ROCM_PATH}/include", f"-I{ROCM_PATH}/include",
...@@ -457,23 +532,36 @@ def compute_hg_build_descriptor( ...@@ -457,23 +532,36 @@ def compute_hg_build_descriptor(
INCLUDE_FLAGS += TORCH_INCLUDE_FLAGS INCLUDE_FLAGS += TORCH_INCLUDE_FLAGS
COMMON_FLAGS = [ COMMON_FLAGS = [
"-fPIC",
"-O3", "-O3",
"-std=c++17", "-std=c++17",
"-fPIC",
"-ffast-math", "-ffast-math",
"-fno-finite-math-only", "-fno-finite-math-only",
"-fno-gpu-rdc", "-fno-gpu-rdc",
"-mno-fma",
] ]
DTK_DEVICE_FLAGS = [ DTK_DEVICE_FLAGS = [
"-DHIP_ENABLE_WARP_SYNC_BUILTINS",
"-mllvm", "-mllvm",
"-slp-phi-tree-bb-max-size=10000", "-disable-machine-sink",
"-mllvm", "-mllvm",
"-enable-num-vgprs-512=true", "-disable-code-sink",
"-Rpass-analysis=kernel-resource-usage", "-mcode-object-version=5",
"-ftemplate-backtrace-limit=0",
] ]
if not _is_rocm_5_7():
DTK_DEVICE_FLAGS += [
"-mllvm",
"-amdgpu-enable-rewrite-partial-reg-uses=false",
"-mllvm",
"-allow-gvn-convergent-call=true",
"-mllvm",
"-disallow-uniform-vmed3-combine=true",
"-mllvm",
"-hcu-pre-emit-load-store-opt=false",
"-mllvm",
"-amdgpu-early-inline-all=true",
"-mllvm",
"-amdgpu-function-calls=false",
]
DTK_DEVICE_FLAGS += _hg_arch_device_flags(HG_ARCHS)
if os.environ.get("FLASH_ATTN_HG_SAVE_TEMPS", "") == "1": if os.environ.get("FLASH_ATTN_HG_SAVE_TEMPS", "") == "1":
DTK_DEVICE_FLAGS.append("--save-temps") DTK_DEVICE_FLAGS.append("--save-temps")
...@@ -555,6 +643,7 @@ def compute_hg_build_descriptor( ...@@ -555,6 +643,7 @@ def compute_hg_build_descriptor(
"obj_dir": obj_dir, "obj_dir": obj_dir,
"sources": _all_sources, "sources": _all_sources,
"objects": objects, "objects": objects,
"compiler": HG_COMPILER,
"compile_flags": compile_flags, "compile_flags": compile_flags,
"link_flags": link_flags, "link_flags": link_flags,
"out_so": out_so, "out_so": out_so,
...@@ -566,16 +655,17 @@ def run_hg_ninja_build(descriptor: dict) -> None: ...@@ -566,16 +655,17 @@ def run_hg_ninja_build(descriptor: dict) -> None:
"""Write build_hg.ninja and run ninja (parallel via MAX_JOBS).""" """Write build_hg.ninja and run ninja (parallel via MAX_JOBS)."""
build_dir = descriptor["build_dir"] build_dir = descriptor["build_dir"]
ninja_file = descriptor["ninja_path"] ninja_file = descriptor["ninja_path"]
compiler = _ninja_shell_join([descriptor["compiler"]])
out_so_ninja = _ninja_escape_path(descriptor["out_so"]) out_so_ninja = _ninja_escape_path(descriptor["out_so"])
lines = [ lines = [
"ninja_required_version = 1.3", "ninja_required_version = 1.3",
"", "",
"rule hipcc_compile", "rule hg_compile",
" command = hipcc -c $in -o $out $FLAGS", f" command = {compiler} -c $in -o $out $FLAGS",
" description = HG compile $in", " description = HG compile $in",
"", "",
"rule hipcc_link", "rule hg_link",
" command = hipcc -shared -o $out @$out.rsp $LINK_FLAGS", f" command = {compiler} -shared -o $out @$out.rsp $LINK_FLAGS",
" rspfile = $out.rsp", " rspfile = $out.rsp",
" rspfile_content = $in", " rspfile_content = $in",
" description = HG link $out", " description = HG link $out",
...@@ -585,9 +675,9 @@ def run_hg_ninja_build(descriptor: dict) -> None: ...@@ -585,9 +675,9 @@ def run_hg_ninja_build(descriptor: dict) -> None:
"", "",
] ]
for src, obj in zip(descriptor["sources"], descriptor["objects"]): for src, obj in zip(descriptor["sources"], descriptor["objects"]):
lines.append(f"build {_ninja_escape_path(obj)}: hipcc_compile {_ninja_escape_path(src)}") lines.append(f"build {_ninja_escape_path(obj)}: hg_compile {_ninja_escape_path(src)}")
obj_list = " ".join(_ninja_escape_path(obj) for obj in descriptor["objects"]) obj_list = " ".join(_ninja_escape_path(obj) for obj in descriptor["objects"])
lines.append(f"build {out_so_ninja}: hipcc_link {obj_list}") lines.append(f"build {out_so_ninja}: hg_link {obj_list}")
lines.append("") lines.append("")
os.makedirs(build_dir, exist_ok=True) os.makedirs(build_dir, exist_ok=True)
...@@ -616,6 +706,7 @@ HG_BUILD_DIR = os.path.join(this_dir, "build", "flash_attn_hg") ...@@ -616,6 +706,7 @@ HG_BUILD_DIR = os.path.join(this_dir, "build", "flash_attn_hg")
HG_SO_BUILD = os.path.join(HG_BUILD_DIR, "libflash_attention.so") HG_SO_BUILD = os.path.join(HG_BUILD_DIR, "libflash_attention.so")
HG_SO_PKG = os.path.join(this_dir, "flash_attn", "lib", "libflash_attention.so") HG_SO_PKG = os.path.join(this_dir, "flash_attn", "lib", "libflash_attention.so")
HG_LIB_DIR = os.path.dirname(HG_SO_PKG) HG_LIB_DIR = os.path.dirname(HG_SO_PKG)
os.environ['PYTORCH_NVCC'] = 'aicc'
# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source. # files included in the source distribution, in case the user compiles from source.
...@@ -663,7 +754,21 @@ if not SKIP_CUDA_BUILD: ...@@ -663,7 +754,21 @@ if not SKIP_CUDA_BUILD:
# HAS_HG_DISPATCH / -lflash_attention are applied there if the .so exists. # HAS_HG_DISPATCH / -lflash_attention are applied there if the .so exists.
hg_compile_defs = [] hg_compile_defs = []
hg_link_args = [] hg_link_args = []
aicc_flags = [
"-mcode-object-version=5",
"-mllvm=-support-768-vgprs=true",
"-mllvm=-disable-machine-sink",
"-mllvm=-disable-code-sink",
"-mllvm=-amdgpu-enable-rewrite-partial-reg-uses=false",
"-mllvm=-allow-gvn-convergent-call=true",
"-mllvm=-disallow-uniform-vmed3-combine=true",
"-mllvm=-hcu-pre-emit-load-store-opt=false",
"-mllvm=-amdgpu-early-inline-all=true",
"-mllvm=-amdgpu-function-calls=false",
"-fno-finite-math-only",
"--gpu-max-threads-per-block=256",
"-mllvm=-unroll-threshold=10000"
]
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI # torch._C._GLIBCXX_USE_CXX11_ABI
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
...@@ -896,7 +1001,7 @@ if not SKIP_CUDA_BUILD: ...@@ -896,7 +1001,7 @@ if not SKIP_CUDA_BUILD:
"-std=c++17", "-std=c++17",
"-DDCU_ASM", "-DDCU_ASM",
# "-mllvm -not-combine-fma=true", # "-mllvm -not-combine-fma=true",
"-mllvm -slp-phi-tree-bb-max-size=10000", # "-mllvm -slp-phi-tree-bb-max-size=10000",
# "-mllvm -allow-cse-cross-bb-convergent-call=true", # "-mllvm -allow-cse-cross-bb-convergent-call=true",
# "-mllvm -full-vectorize-slp=true", # "-mllvm -full-vectorize-slp=true",
f"-DFLASH_ATTENTION_BF16_TYPE={bf16_type}", f"-DFLASH_ATTENTION_BF16_TYPE={bf16_type}",
...@@ -936,6 +1041,7 @@ if not SKIP_CUDA_BUILD: ...@@ -936,6 +1041,7 @@ if not SKIP_CUDA_BUILD:
] ]
+ generator_flag + generator_flag
+ hg_compile_defs + hg_compile_defs
+ aicc_flags
# + cc_flag # + cc_flag
), ),
}, },
...@@ -944,6 +1050,7 @@ if not SKIP_CUDA_BUILD: ...@@ -944,6 +1050,7 @@ if not SKIP_CUDA_BUILD:
Path(this_dir) / "csrc" / "flash_attn", Path(this_dir) / "csrc" / "flash_attn",
Path(this_dir) / "csrc" / "flash_attn" / "src", Path(this_dir) / "csrc" / "flash_attn" / "src",
Path(this_dir) / "csrc" / "cutlass" / "include", Path(this_dir) / "csrc" / "cutlass" / "include",
],
) )
) )
...@@ -1051,13 +1158,16 @@ class NinjaBuildExtension(BuildExtension): ...@@ -1051,13 +1158,16 @@ class NinjaBuildExtension(BuildExtension):
if os.path.isdir(HG_SRC_DIR): if os.path.isdir(HG_SRC_DIR):
os.makedirs(HG_BUILD_DIR, exist_ok=True) os.makedirs(HG_BUILD_DIR, exist_ok=True)
_maybe_clean_hg_build_dir(HG_BUILD_DIR) _maybe_clean_hg_build_dir(HG_BUILD_DIR)
print("=== Building HG libflash_attention.so (mode=all, gfx938, ninja) ===")
try: try:
desc = compute_hg_build_descriptor( desc = compute_hg_build_descriptor(
HG_SRC_DIR, HG_SRC_DIR,
HG_BUILD_DIR, HG_BUILD_DIR,
mode="all", mode="all",
extra_options_raw="-DGFX_VERSION=938 -Wl,-Bsymbolic", extra_options_raw="-DGFX_VERSION=938;936 -Wl,-Bsymbolic",
)
print(
"=== Building HG libflash_attention.so "
f"(mode=all, gfx938/gfx936, ninja, compiler={desc['compiler']}) ==="
) )
run_hg_ninja_build(desc) run_hg_ninja_build(desc)
if os.path.isfile(HG_SO_BUILD): if os.path.isfile(HG_SO_BUILD):
...@@ -1066,11 +1176,11 @@ class NinjaBuildExtension(BuildExtension): ...@@ -1066,11 +1176,11 @@ class NinjaBuildExtension(BuildExtension):
use_hg = True use_hg = True
print(f"=== Copied HG .so -> {HG_SO_PKG} ===") print(f"=== Copied HG .so -> {HG_SO_PKG} ===")
else: else:
print("WARNING: HG build completed but output .so is missing; continuing without HG dispatch") raise RuntimeError("Error: HG build completed but output .so is missing")
except Exception as e: except Exception as e:
print(f"WARNING: HG build failed ({e}), continuing without HG dispatch") raise RuntimeError(f"Error: HG build failed ({e})")
else: else:
print(f"WARNING: HG source directory not found ({HG_SRC_DIR}), continuing without HG dispatch") raise RuntimeError(f"Error: HG source directory not found ({HG_SRC_DIR})")
else: else:
# FLASH_BUILD_HG=0 should deterministically disable dispatch even if stale artifacts exist. # FLASH_BUILD_HG=0 should deterministically disable dispatch even if stale artifacts exist.
if os.path.isfile(HG_SO_PKG): if os.path.isfile(HG_SO_PKG):
......
import os
import sys
import math
import torch
import pickle
import time
import numpy
import argparse
import random
from datetime import datetime
use_cuda_toolkits = os.path.exists("/usr/local/cuda/bin/nvcc")
use_rocm_toolkits = os.path.exists("/opt/rocm/llvm/bin/clang")
use_dtk_toolkits = os.path.exists("/opt/dtk/bin/aicc")
if (use_cuda_toolkits):
from vllm.vllm_flash_attn import flash_attn_varlen_func
elif (use_rocm_toolkits or use_dtk_toolkits):
try:
from flash_attention_interface import flash_attn_varlen_func, flash_attn_2_cuda, flash_attn_with_kvcache
except ModuleNotFoundError:
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
import flash_attn_2_cuda as flash_attn_cuda
def _require_hg_varlen_symbol(name: str):
symbol = getattr(flash_attn_cuda, name, None)
if symbol is None:
raise RuntimeError(
f"{name} is unavailable in this build. Rebuild flash_attn with HAS_HG_DISPATCH enabled."
)
return symbol
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, do_assert=True, cos_threshold=1e-5) -> None:
assert x.shape == y.shape, "for {}, x and y must have the same shape".format(name)
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
rel_diff_mean = (x / y).abs().mean().item()
rel_diff_max = (x / y).abs().max().item()
print("name:{} cos_diff={:.12f}, RMSE=\x1b[35m{:.12f}\x1b[0m, amax_diff=\x1b[35m{:.12f}\x1b[0m, REL=\x1b[35m{:.12f}\x1b[0m, rel_max=\x1b[35m{:.12f}\x1b[0m".format(
name, cos_diff, RMSE, amax_diff, rel_diff_mean, rel_diff_max))
if (do_assert): assert cos_diff < cos_threshold
def scaled_dot_product_attention(__query, __key, __value, h_q, h_kv, is_causal=False, USE_CPU=False, return_max_sum=False, original_seqlen_kv=0, split_slice=0, is_bshd=False, window_size=(-1, -1)):
__query = __query.transpose(0, 1).contiguous()
__key = __key.transpose(0, 1).contiguous()
__value = __value.transpose(0, 1).contiguous()
# 判断是否使用 CPU 计算 golden, 避免 blas 的影响
original_device = __query.device
original_dtype = __query.dtype
if (USE_CPU):
__query = __query.cpu()
__key = __key.cpu()
__value = __value.cpu()
# print("scaled_dot_product_attention: ", query.shape, key.shape, value.shape)
__query = __query.float()
__key = __key.float()
__value = __value.float()
# 如果按照官方的方法返回
if (not return_max_sum):
__key = __key.repeat_interleave(h_q // h_kv, dim=0)
__value = __value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = __query @ __key.transpose(-2, -1) / math.sqrt(__query.size(-1))
# MTP > 1, causal/local mask applied
if (window_size != (-1, -1)):
s_q = __query.shape[-2]
s_k = __key.shape[-2]
left, right = window_size
if left < 0:
left = s_k
if right < 0:
right = s_k
row_idx = torch.arange(s_q, dtype=torch.int32, device=attn_weight.device)[:, None]
col_idx = torch.arange(s_k, dtype=torch.int32, device=attn_weight.device)[None, :]
col_idx_limit_left = row_idx + s_k - s_q - left
col_idx_limit_right = row_idx + s_k - s_q + right
temp_mask = (col_idx >= col_idx_limit_left) & (col_idx <= col_idx_limit_right)
attn_weight = attn_weight.masked_fill(temp_mask.logical_not(), float("-inf"))
elif (is_causal):
s_q = __query.shape[-2]
s_k = __key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=__query.dtype, device=attn_weight.device)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=attn_weight.device).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(__query.dtype)
attn_weight += attn_bias
# some codes for debug
scores_max = attn_weight.to(torch.float32).max(-1)[0]
scores_sum = torch.exp(attn_weight.to(torch.float32) - scores_max.unsqueeze(-1)).sum(dim=-1)
# original codes
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
output = attn_weight @ __value
output = output.transpose(0, 1).contiguous()
return output.to(original_device).to(original_dtype), lse.to(original_device), scores_max.to(original_device), scores_sum.to(original_device)
def set_random_seed(seed=0):
random.seed(seed) # 设置 Python 的随机种子
numpy.random.seed(seed) # 设置 NumPy 的随机种子
torch.manual_seed(seed) # 设置 PyTorch 的随机种子
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) # 设置所有 GPU 的随机种子
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['OMP_NUM_THREADS'] = '1' # 设置 OpenMP 的线程数
torch.set_num_threads(1) # 设置 PyTorch 的线程数
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--load', default=False, action='store_true', help='load path')
parser.add_argument('--trace', default=False, action='store_true', help='whether dump perf traces')
parser.add_argument('--bf16', default=False, action='store_true', help='whether use bfloat16 as main dtype')
parser.add_argument('--fp8', default=False, action='store_true', help='whether use fp8_e4m3 inputs for HG decode')
parser.add_argument('--pressure', default=False, action='store_true', help='whether do pressure test')
parser.add_argument('--cpu', default=False, action='store_true', help='whether compute golden via cpu')
parser.add_argument('--pad', default=False, action='store_true', help='whether make query uncontiguous to simulate vllm behaviors')
parser.add_argument('--iterations', type=int, default=100, help='pressure test times')
parser.add_argument('--block_size', type=int, default=128, help='page block_size')
parser.add_argument('--batch-size', type=int, default=1, help='batch size for generated inputs')
parser.add_argument('--seq-q', type=int, default=4, help='query length per batch for generated inputs')
parser.add_argument('--seq-k', type=int, default=2048, help='kv length per batch for generated inputs')
parser.add_argument('--num-heads', type=int, default=24, help='number of query heads for generated inputs')
parser.add_argument('--num-heads-kv', type=int, default=2, help='number of kv heads for generated inputs')
parser.add_argument('--head-dim-qk', type=int, default=128, help='query/key head dimension')
parser.add_argument('--head-dim-v', type=int, default=128, help='value head dimension')
parser.add_argument('--no-causal', dest='causal', default=True, action='store_false', help='disable causal mask for generated inputs')
parser.add_argument('--window-left', type=int, default=-1, help='left sliding window size')
parser.add_argument('--window-right', type=int, default=-1, help='right sliding window size')
parser.add_argument('--seed', default=False, action='store_true', help='whether do pressure test')
args = parser.parse_args()
if (args.seed):
set_random_seed(212)
# 从文件加载输入
if (args.load):
nvidia_packet = torch.load("./demo.pt")
query, key, value, cu_seqlens_q, max_seqlen_q, cache_seqlens, max_seqlen_k, softmax_scale, causal, window_size, alibi_slopes, page_table, softcap, fa_version, q_descale, k_descale, v_descale = nvidia_packet["inputs"]
vllm_golden = nvidia_packet["outputs"]
# 解析出必要的参数
batch_size = page_table.shape[0]
assert batch_size == cu_seqlens_q.shape[0] - 1, "check batch size"
page_block_size = key.shape[1]
num_heads_kv = key.shape[2]
num_heads = query.shape[1]
head_dim_qk = query.shape[2]
head_dim_v = key.shape[3]
infer_dtype = query.dtype
else:
# 随机生成 seqkv
batch_size = args.batch_size
# 得到 Q 的长度
seqlen_q = [args.seq_q for i in range(batch_size)]
seqlen_q_sum = sum(seqlen_q)
max_seqlen_q = max(seqlen_q)
cu_seqlens_q = numpy.array([0] + numpy.cumsum(seqlen_q).tolist()).astype("int32")
cu_seqlens_q = torch.from_numpy(cu_seqlens_q)
# 得到 KV 的长度
cache_seqlens = [args.seq_k for i in range(batch_size)]
# 指定分页块的大小, nvidia 64, ours 128
page_block_size = 16 if (use_cuda_toolkits) else args.block_size
# 根据分页块大小计算实际需要的页表的大小
max_seqlen_k = max(cache_seqlens)
seqlen_kv_real_required_page = [math.ceil(it / page_block_size) for it in cache_seqlens]
seqlen_kv_real_required_page_sum = sum(seqlen_kv_real_required_page)
# 默认按照最大 seqlenkv 的来分配
seqlen_kv_max_required_page = math.ceil(max_seqlen_k / page_block_size)
seqlen_kv_max_required_page_total = batch_size * seqlen_kv_max_required_page
# 打乱页表
shuffle = True
if (shuffle):
block_random = torch.randperm(seqlen_kv_max_required_page_total, dtype=torch.int32, device="cuda")
else:
block_random = torch.arange(seqlen_kv_max_required_page_total , dtype=torch.int32)
page_table = []
seq_block_incre = 0
for i in range(batch_size):
blocks_pad = [0] * seqlen_kv_max_required_page
if (shuffle):
blocks_pad[:seqlen_kv_real_required_page[i]] = block_random[seq_block_incre: seq_block_incre + seqlen_kv_real_required_page[i]].cpu().tolist()
seq_block_incre += seqlen_kv_real_required_page[i]
else:
blocks_pad = block_random[seq_block_incre: seq_block_incre + seqlen_kv_max_required_page].cpu().tolist()
seq_block_incre += seqlen_kv_max_required_page
page_table.append(torch.IntTensor(blocks_pad))
page_table = torch.stack(page_table).contiguous().to("cuda")
# 创建基本参数
head_dim_qk = args.head_dim_qk
head_dim_v = args.head_dim_v
num_heads = args.num_heads
num_heads_kv = args.num_heads_kv
infer_dtype = torch.float16 # deepseek 默认使用 bfloat16 推理
if (args.bf16): infer_dtype = torch.bfloat16 # 除非命令行指定用 fp16, 不受 args.dtype 影响
softmax_scale = 1.0 / math.sqrt(head_dim_qk)
causal = args.causal
window_size = (args.window_left, args.window_right)
alibi_slopes = None
softcap = 0.0
fa_version = 2
q_descale = torch.ones((batch_size, num_heads), dtype=torch.float32, device="cuda")
k_descale = torch.ones((batch_size, num_heads_kv), dtype=torch.float32, device="cuda")
v_descale = torch.ones((batch_size, num_heads_kv), dtype=torch.float32, device="cuda")
# 创建输入张量
if (args.pad):
query_origin_tensor = torch.randn((seqlen_q_sum, num_heads + 16, head_dim_qk), dtype=infer_dtype, device="cuda")
q = query_origin_tensor[:, :num_heads]
else:
q = torch.randn((seqlen_q_sum, num_heads, head_dim_qk), dtype=infer_dtype, device="cuda")
k_cache = torch.randn((seqlen_kv_max_required_page_total, page_block_size, num_heads_kv, head_dim_qk), device="cuda", dtype=infer_dtype)
v_cache = torch.randn((seqlen_kv_max_required_page_total, page_block_size, num_heads_kv, head_dim_v), device="cuda", dtype=infer_dtype)
vllm_golden = None
cu_seqlens_q = cu_seqlens_q.to(q.device)
cache_seqlens = torch.from_numpy(numpy.array(cache_seqlens).astype("int32")).to(q.device)
q_ref = q
k_cache_ref = k_cache
v_cache_ref = v_cache
if args.fp8:
if not hasattr(torch, "float8_e4m3fn"):
raise RuntimeError("This PyTorch build does not support torch.float8_e4m3fn")
q = q.to(torch.float8_e4m3fn)
k_cache = k_cache.to(torch.float8_e4m3fn)
v_cache = v_cache.to(torch.float8_e4m3fn)
q_ref = q.to(infer_dtype)
k_cache_ref = k_cache.to(infer_dtype)
v_cache_ref = v_cache.to(infer_dtype)
# 展示一下输入数据
print("--------------------------------------------------------------------------------------------")
print("q: ", q.shape, q.dtype, q.is_contiguous(), q.stride())
print("k_cache: ", k_cache.shape, k_cache.dtype, k_cache.is_contiguous(), k_cache.stride())
print("v_cache: ", v_cache.shape, v_cache.dtype, v_cache.is_contiguous(), v_cache.stride())
print("cu_seqlens_q: ", cu_seqlens_q.shape, cu_seqlens_q.dtype, cu_seqlens_q.is_contiguous())
print("cu_seqlens_q: ", cu_seqlens_q)
print("max_seqlen_q: ", max_seqlen_q)
print("cache_seqlens: ", cache_seqlens)
print("max_seqlen_k: ", max_seqlen_k)
print("softmax_scale: ", softmax_scale)
print("causal: ", causal)
print("window_size: ", window_size)
print("alibi_slopes: ", alibi_slopes)
print("page_table: ", page_table.shape, page_table.dtype, page_table.is_contiguous(), page_table.stride())
print("page_table: ", page_table)
print("softcap: ", softcap)
print("fa_version: ", fa_version)
print("q_descale: ", q_descale.shape, q_descale.dtype, q_descale.tolist())
print("k_descale: ", k_descale.shape, k_descale.dtype, k_descale.tolist())
print("v_descale: ", v_descale.shape, v_descale.dtype, v_descale.tolist())
print("--------------------------------------------------------------------------------------------")
# 先从 kvcache 中还原出 key 和 value
key_original = []
value_original = []
for b in range(batch_size):
# 获取页表索引
index = page_table[b]
# 获取实际的索引
max_page_blocks = math.ceil(cache_seqlens[b] / page_block_size)
actual_index = index[:max_page_blocks]
# 根据该页表索引获取当前 seqlenkv 的内容
key_content = k_cache_ref[actual_index]
# reshape 回去
key_content = key_content.view(-1, num_heads_kv, head_dim_qk)[:cache_seqlens[b]].contiguous()
# 同理
value_content = v_cache_ref[actual_index].view(-1, num_heads_kv, head_dim_v)[:cache_seqlens[b]].contiguous()
key_original.append(key_content)
value_original.append(value_content)
# 同理还原出 query 的内容
query_original = []
cum_q = 0
for b in range(batch_size):
query_len = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
query_content = q_ref[cum_q: cum_q + query_len]
query_original.append(query_content.contiguous())
cum_q += query_len
# 重新实现 self-attention
golden = []
golden_lse = []
golden_max = []
for b in range(batch_size):
tmp_output, lse, scores_max, scores_sum = scaled_dot_product_attention(query_original[b], key_original[b], value_original[b], num_heads, num_heads_kv, is_causal=causal, USE_CPU=args.cpu, window_size=window_size)
golden.append(tmp_output)
golden_lse.append(lse)
golden_max.append(scores_max)
golden = torch.cat(golden, dim=0)
golden_lse = torch.cat(golden_lse, dim=-1)
golden_max = torch.cat(golden_max, dim=-1)
print("golden: ", golden.shape)
print("golden_lse: ", golden_lse.shape)
print("--------------------------------------------------------------------------------------------")
if (True):
# fa_output, fa_lse = flash_attn_2_cuda.prefix_decode_varlen_fwd(
bshd_pa_decode = _require_hg_varlen_symbol("hg_prefix_decode_varlen_fwd")
fa_output, fa_lse = bshd_pa_decode(
q,
k_cache,
v_cache,
None, # out_
cu_seqlens_q,
None, # cu_seqlens_k
cache_seqlens,
alibi_slopes,
page_table,
max_seqlen_q,
max_seqlen_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
causal,
window_size[0],
window_size[1],
softcap,
True, # return_softmax_lse,
1,
q_descale if args.fp8 else None,
k_descale if args.fp8 else None,
v_descale if args.fp8 else None,
None, # s_aux
infer_dtype == torch.bfloat16,
)
else:
fa_output, fa_lse, *rest = flash_attn_with_kvcache(
q=q,
k_cache=k_cache,
v_cache=v_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cache_seqlens,
max_seqlen_q=max_seqlen_q,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
torch.cuda.synchronize()
if (vllm_golden is not None):
# 检查保存流程是否有错误
cal_diff(fa_output, vllm_golden, "check")
print("fa_output: ", fa_output.shape)
if (fa_lse is not None): print("fa_lse: ", fa_lse.shape)
# 检验精度如何
fp8_threshold = 5e-3
cal_diff(golden, fa_output, "accuracy", True, fp8_threshold if args.fp8 else 1e-5)
if (fa_lse is not None): cal_diff(golden_lse, fa_lse, "softmax_lse", True, fp8_threshold if args.fp8 else 1e-5)
print("--------------------------------------------------------------------------------------------")
# benchmark 性能数据
import triton
def benchmark_prefix_prefill():
_ = bshd_pa_decode(
q,
k_cache,
v_cache,
None,
cu_seqlens_q,
None,
cache_seqlens,
alibi_slopes,
page_table,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale,
False,
causal,
window_size[0],
window_size[1],
softcap,
True,
1,
q_descale if args.fp8 else None,
k_descale if args.fp8 else None,
v_descale if args.fp8 else None,
None,
infer_dtype == torch.bfloat16,
)
# 适时关闭, 用于 debug
if ((os.getenv("FA_DEBUG") is None) and (os.getenv("HIP_LOG_LEVEL") is None) and not args.trace):
import triton
t = triton.testing.do_bench_cudagraph(benchmark_prefix_prefill)
FLOPS = float(0)
BYTES = float(0)
for b in range(batch_size):
batch_seqlen_q = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
batch_seqlen_k = cache_seqlens[b]
effective_seqlen_k = batch_seqlen_k
if window_size != (-1, -1):
window_left, window_right = window_size
left = batch_seqlen_k if window_left < 0 else window_left
right = batch_seqlen_k if window_right < 0 else window_right
effective_seqlen_k = min(batch_seqlen_k, left + batch_seqlen_q + right)
undo_flops = batch_seqlen_q * batch_seqlen_q / 2 if (causal and window_size == (-1, -1)) else 0
attn_elems = batch_seqlen_q * effective_seqlen_k - undo_flops
qk_flops = num_heads * attn_elems * head_dim_qk * 2
pv_flops = num_heads * attn_elems * head_dim_v * 2
FLOPS += qk_flops + pv_flops
q_load = batch_seqlen_q * num_heads * head_dim_qk
k_load = effective_seqlen_k * num_heads_kv * head_dim_qk # k load not only once
v_load = effective_seqlen_k * num_heads_kv * head_dim_v
BYTES += q_load * q.element_size() + k_load * k_cache.element_size() + v_load * v_cache.element_size() # ignore storation ?
print(f"Performance: {t:.3f} ms, \x1b[35m{FLOPS / 10 ** 9 / t:.2f}\x1b[0m TFLOPS, \x1b[35m{BYTES / 10 ** 6 / t:.0f}\x1b[0m GB/s")
# 压力测试
if (args.pressure):
pressure_count = max(100, args.iterations)
for p in range(pressure_count):
pressure_fa_output = torch.zeros_like(fa_output)
pressure_fa_output, _ = bshd_pa_decode(
q.clone(),
k_cache.clone(),
v_cache.clone(),
None,
cu_seqlens_q,
None,
cache_seqlens,
alibi_slopes,
page_table,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale,
False,
causal,
window_size[0],
window_size[1],
softcap,
True,
1,
q_descale if args.fp8 else None,
k_descale if args.fp8 else None,
v_descale if args.fp8 else None,
infer_dtype == torch.bfloat16,
)
torch.cuda.synchronize()
is_equal = torch.equal(pressure_fa_output, fa_output)
if (not is_equal): cal_diff(pressure_fa_output, fa_output, "pressure")
assert is_equal, "\x1b[31mUnstable\x1b[0m!"
del pressure_fa_output
sys.stdout.write("\rPressure Test: {}/{}".format(p + 1, pressure_count))
print(" \x1b[32mPASS\x1b[0m")
print("-----------------------------------------------------------------------------------")
import torch import torch
import os
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -6,6 +7,39 @@ from vllm.triton_utils import tl, triton ...@@ -6,6 +7,39 @@ from vllm.triton_utils import tl, triton
import math import math
import time import time
from typing import Optional
UNIFIED_BLOCK_SIZE = int(os.getenv("UNIFIED_BLOCK_SIZE", "128"))
def estimate_unified_attention_bytes(
batch_size,
seqlen_q,
seqlen_k,
nheads,
nheads_k,
d,
block_size,
q_bytes,
k_bytes,
v_bytes,
out_bytes=0,
d_v=None,
window_size=(-1, -1),
):
d_v = d if d_v is None else d_v
effective_seqlen_k = seqlen_k
if window_size != (-1, -1):
window_left, window_right = window_size
left = seqlen_k if window_left < 0 else window_left
right = seqlen_k if window_right < 0 else window_right
effective_seqlen_k = min(seqlen_k, left + seqlen_q + right)
num_blocks = math.ceil(effective_seqlen_k / block_size)
q_bytes_total = batch_size * seqlen_q * nheads * d * q_bytes
kv_bytes_total = effective_seqlen_k * batch_size * nheads_k * (d * k_bytes + d_v * v_bytes)
out_bytes_total = batch_size * seqlen_q * nheads * d_v * out_bytes
metadata_bytes = (batch_size + 1) * 4 + batch_size * 4 + batch_size * num_blocks * 4
return q_bytes_total + kv_bytes_total + out_bytes_total + metadata_bytes
import pytest import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -18,19 +52,8 @@ import pdb ...@@ -18,19 +52,8 @@ import pdb
from einops import rearrange, repeat from einops import rearrange, repeat
from flash_attn import ( from flash_attn import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
varlen_fwd_unified, varlen_fwd_unified,
) )
from flash_attn import flash_attn_func
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb
MAX_HEADDIM_SM8x = 192 MAX_HEADDIM_SM8x = 192
...@@ -109,6 +132,7 @@ def kernel_unified_attention_2d( ...@@ -109,6 +132,7 @@ def kernel_unified_attention_2d(
TILE_SIZE: tl.constexpr, # int must be power of 2 TILE_SIZE: tl.constexpr, # int must be power of 2
HEAD_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
VALUE_HEAD_SIZE: tl.constexpr, # int
USE_ALIBI_SLOPES: tl.constexpr, # bool USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_ALIBI_SQRT: tl.constexpr, # bool USE_ALIBI_SQRT: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool USE_QQ_BIAS: tl.constexpr, # bool
...@@ -167,6 +191,7 @@ def kernel_unified_attention_2d( ...@@ -167,6 +191,7 @@ def kernel_unified_attention_2d(
) )
dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
value_dim_mask = tl.where(offs_d < VALUE_HEAD_SIZE, 1, 0).to(tl.int1)
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
...@@ -296,7 +321,7 @@ def kernel_unified_attention_2d( ...@@ -296,7 +321,7 @@ def kernel_unified_attention_2d(
# V : (TILE_SIZE, HEAD_SIZE) # V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load( V_load = tl.load(
value_cache_ptr + v_offset, value_cache_ptr + v_offset,
mask=dim_mask[None, :] & tile_mask[:, None], mask=value_dim_mask[None, :] & tile_mask[:, None],
other=0.0, other=0.0,
) )
...@@ -425,7 +450,7 @@ def kernel_unified_attention_2d( ...@@ -425,7 +450,7 @@ def kernel_unified_attention_2d(
tl.store( tl.store(
output_ptr + output_offset, output_ptr + output_offset,
acc, acc,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], mask=value_dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
) )
...@@ -1039,6 +1064,7 @@ def unified_attention( ...@@ -1039,6 +1064,7 @@ def unified_attention(
TILE_SIZE=TILE_SIZE_PREFILL, TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size, HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
VALUE_HEAD_SIZE=v.shape[3],
USE_ALIBI_SLOPES=use_alibi_slopes, USE_ALIBI_SLOPES=use_alibi_slopes,
USE_ALIBI_SQRT=use_alibi_sqrt, USE_ALIBI_SQRT=use_alibi_sqrt,
USE_QQ_BIAS=use_qq_bias, USE_QQ_BIAS=use_qq_bias,
...@@ -1235,7 +1261,14 @@ def make_paged_kv(k_list, v_list, block_size, device, dtype): ...@@ -1235,7 +1261,14 @@ def make_paged_kv(k_list, v_list, block_size, device, dtype):
return k_cache, v_cache, block_table return k_cache, v_cache, block_table
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False, is_e5m2: bool = False) -> None: def cal_diff(
x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False,
is_e5m2: bool = False,
cos_threshold: Optional[float] = None,
) -> None:
torch_dtype = x.dtype torch_dtype = x.dtype
x, y = x.double(), y.double() x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item() RMSE = ((x - y) * (x - y)).mean().sqrt().item()
...@@ -1245,23 +1278,60 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False, ...@@ -1245,23 +1278,60 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False,
if is_e5m2: if is_e5m2:
assert cos_diff < 1e-2 assert cos_diff < 1e-2
elif use_fp8: elif use_fp8:
assert cos_diff < 1e-3 assert cos_diff < 5e-3
elif cos_threshold is not None:
assert cos_diff < cos_threshold
else: else:
assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5) assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5)
class _NoUnifiedFallback:
def __init__(self, module):
self._module = module
def __getattr__(self, name):
if name == "varlen_fwd_unified":
raise AssertionError("unexpected fallback to flash_attn_cuda.varlen_fwd_unified")
return getattr(self._module, name)
def varlen_fwd_unified_expect_hg(expected_symbol, *args, **kwargs):
fn_globals = varlen_fwd_unified.__globals__
original_module = fn_globals["flash_attn_cuda"]
original_require = fn_globals["_require_hg_varlen_symbol"]
called_symbols = []
def require_hg_symbol(name):
called_symbols.append(name)
return original_require(name)
fn_globals["flash_attn_cuda"] = _NoUnifiedFallback(original_module)
fn_globals["_require_hg_varlen_symbol"] = require_hg_symbol
try:
result = varlen_fwd_unified(*args, **kwargs)
finally:
fn_globals["flash_attn_cuda"] = original_module
fn_globals["_require_hg_varlen_symbol"] = original_require
assert expected_symbol in called_symbols, (
f"expected {expected_symbol}, got HG calls {called_symbols}"
)
return result
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Accuracy tests # Accuracy tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("mha_type", ["gqa"]) @pytest.mark.parametrize("mha_type", ["gqa"])
@pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("window_size", [(-1, -1)]) @pytest.mark.parametrize("window_size", [(-1, -1), (511, 0)])
@pytest.mark.parametrize("use_alibi_sqrt", [True, False]) @pytest.mark.parametrize("use_alibi_sqrt", [True, False])
@pytest.mark.parametrize("use_qq_bias", [True, False]) # seqlen_q > seqlen_k 时 skip @pytest.mark.parametrize("use_qq_bias", [True, False]) # seqlen_q > seqlen_k 时 skip
@pytest.mark.parametrize("use_sinks", [True, False]) @pytest.mark.parametrize("use_sinks", [True, False])
@pytest.mark.parametrize("use_mm_prefix", [True, False]) @pytest.mark.parametrize("use_mm_prefix", [True, False])
@pytest.mark.parametrize("d", [128, 256]) @pytest.mark.parametrize("d,d_v", [(128, 128), (192, 128), (256, 256)])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"batch_size,seqlen_q,seqlen_k,block_size", "batch_size,seqlen_q,seqlen_k,block_size",
[ [
...@@ -1274,43 +1344,63 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False, ...@@ -1274,43 +1344,63 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False,
# --- 场景 2: Decode 场景 (增量推理) --- # --- 场景 2: Decode 场景 (增量推理) ---
# 验证 seqlen_q=1 时,如何正确从 KV Cache 的最后位置读取信息 # 验证 seqlen_q=1 时,如何正确从 KV Cache 的最后位置读取信息
# 此时 qq_bias 实际上只退化为向量加法,最容易测出指针偏移错误 # 此时 qq_bias 实际上只退化为向量加法,最容易测出指针偏移错误
(8, 1, 2048, 128), # 高 Batch 的标准 Decode (64, 1, 2048, UNIFIED_BLOCK_SIZE),
(1, 1, 4096, 128), # 超长上下文 Decode,验证大索引寻址 (16, 1, 2048, UNIFIED_BLOCK_SIZE),
(1, 1, 4096, UNIFIED_BLOCK_SIZE),
# --- 场景 3: Chunked Prefill / Speculative Decoding (分段/投机采样) --- (64, 4, 2048, UNIFIED_BLOCK_SIZE),
# Q 小于 K,但大于 1。这是最难写的逻辑,验证 Is_causal 的动态截断 (32, 4, 2048, UNIFIED_BLOCK_SIZE),
(2, 128, 1024, 128), # Q 是一小段,K 是长历史 (16, 4, 2048, UNIFIED_BLOCK_SIZE),
(4, 256, 512, 128), # 验证 Q 和 K 比例较近时的处理 (8, 4, 2048, UNIFIED_BLOCK_SIZE), # 高 Batch 的标准 Decode
(4, 4, 2048, UNIFIED_BLOCK_SIZE),
# --- 场景 4: 边界非对称尺寸 (非 2 的幂次) --- (2, 4, 2048, UNIFIED_BLOCK_SIZE),
# 专门用来抓那些“假设数据一定是 BlockSize 整数倍”的 Bug (1, 4, 4096, UNIFIED_BLOCK_SIZE), # 超长上下文 Decode,验证大索引寻址
# --- 场景 3: Prefix Prefill ---
(1, 16, 128, UNIFIED_BLOCK_SIZE), # fp8 prefill lower boundary
(1, 32, 512, UNIFIED_BLOCK_SIZE),
(2, 32, 513, UNIFIED_BLOCK_SIZE), # non block-aligned KV length
(2, 64, 2048, UNIFIED_BLOCK_SIZE),
(3, 96, 1537, UNIFIED_BLOCK_SIZE), # non power-of-two batch/length
(2, 128, 1024, UNIFIED_BLOCK_SIZE),
# # --- 场景 4: 边界非对称尺寸 (非 2 的幂次) ---
# # 专门用来抓那些“假设数据一定是 BlockSize 整数倍”的 Bug
(1, 127, 127, 128), # 刚好差 1 个填满 Block (1, 127, 127, 128), # 刚好差 1 个填满 Block
(2, 33, 1025, 128), # 非常细碎的 Block 和不规则长度 (2, 33, 1025, 128), # 非常细碎的 Block 和不规则长度
], ],
) )
def test_unified_attn_2d( def test_unified_attn_2d(
batch_size, seqlen_q, seqlen_k, block_size, batch_size, seqlen_q, seqlen_k, block_size,
d, causal, window_size, softcap, d, d_v, causal, window_size, softcap,
mha_type, dtype, mha_type, dtype,
use_alibi_sqrt, use_qq_bias, use_sinks, use_mm_prefix, use_alibi_sqrt, use_qq_bias, use_sinks, use_mm_prefix, use_fp8,
): ):
device = torch.device("cuda") device = torch.device("cuda")
torch.manual_seed(42) torch.manual_seed(42)
nheads = 8 nheads = 16
nheads_k = 1 if mha_type == "gqa" else nheads nheads_k = 2 if mha_type == "gqa" else nheads
softmax_scale = d ** (-0.5) softmax_scale = d ** (-0.5)
MAX_MM_RANGES = 2 MAX_MM_RANGES = 2
# skip invalid combos # skip invalid combos
if use_alibi_sqrt and not causal: if use_alibi_sqrt and not causal:
pytest.skip("alibi_sqrt only tested with causal=True") pytest.skip("alibi_sqrt only tested with causal=True")
if use_alibi_sqrt:
pytest.skip("HG unified attention does not support alibi_sqrt yet")
if use_qq_bias and seqlen_q > seqlen_k: if use_qq_bias and seqlen_q > seqlen_k:
pytest.skip("qq_bias requires seqlen_q <= seqlen_k") pytest.skip("qq_bias requires seqlen_q <= seqlen_k")
if use_qq_bias:
pytest.skip("HG unified attention does not support qq_bias yet")
if use_mm_prefix and seqlen_q > seqlen_k: if use_mm_prefix and seqlen_q > seqlen_k:
pytest.skip("mm_prefix not supported when seqlen_q > seqlen_k") pytest.skip("mm_prefix not supported when seqlen_q > seqlen_k")
if use_mm_prefix:
pytest.skip("HG unified attention does not support mm_prefix yet")
if (not use_fp8) and use_sinks and (seqlen_q == 1 or 1 < seqlen_q < 16):
pytest.skip("b16 prefix decode sinks are not supported yet")
if (not use_fp8) and use_sinks and seqlen_q >= 16 and d == 256 and d_v == 256:
pytest.skip("b16 prefix prefill 256/256 sinks are not supported yet")
# if use_mm_prefix and not causal: # if use_mm_prefix and not causal:
# pytest.skip("mm_prefix_range is only meaningful with causal=True") # pytest.skip("mm_prefix_range is only meaningful with causal=True")
...@@ -1318,15 +1408,30 @@ def test_unified_attn_2d( ...@@ -1318,15 +1408,30 @@ def test_unified_attn_2d(
for _ in range(batch_size): for _ in range(batch_size):
q_list.append(torch.randn(seqlen_q, nheads, d, device=device, dtype=dtype)) q_list.append(torch.randn(seqlen_q, nheads, d, device=device, dtype=dtype))
k_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype)) k_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype))
v_list.append(torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype)) v_list.append(torch.randn(seqlen_k, nheads_k, d_v, device=device, dtype=dtype))
if use_fp8:
fp8_dtype = current_platform.fp8_dtype()
q_kernel_list = [q.to(fp8_dtype) for q in q_list]
k_kernel_list = [k.to(fp8_dtype) for k in k_list]
v_kernel_list = [v.to(fp8_dtype) for v in v_list]
q_ref_list = [q.to(dtype) for q in q_kernel_list]
k_ref_list = [k.to(dtype) for k in k_kernel_list]
v_ref_list = [v.to(dtype) for v in v_kernel_list]
else:
q_kernel_list, k_kernel_list, v_kernel_list = q_list, k_list, v_list
q_ref_list, k_ref_list, v_ref_list = q_list, k_list, v_list
q_varlen = torch.cat(q_list, dim=0) q_varlen = torch.cat(q_kernel_list, dim=0)
cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum( cu_seqlens_q[1:] = torch.cumsum(
torch.tensor([seqlen_q] * batch_size, dtype=torch.int32), dim=0 torch.tensor([seqlen_q] * batch_size, dtype=torch.int32), dim=0
) )
seqused_k = torch.tensor([seqlen_k] * batch_size, device=device, dtype=torch.int32) seqused_k = torch.tensor([seqlen_k] * batch_size, device=device, dtype=torch.int32)
k_cache, v_cache, block_table = make_paged_kv(k_list, v_list, block_size, device, dtype) k_cache, v_cache, block_table = make_paged_kv(k_kernel_list, v_kernel_list, block_size, device, q_varlen.dtype)
q_descale = torch.ones((batch_size, nheads), device=device, dtype=torch.float32) if use_fp8 else None
k_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32) if use_fp8 else None
v_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32) if use_fp8 else None
# Build optional tensors # Build optional tensors
alibi_slopes = None alibi_slopes = None
...@@ -1339,7 +1444,8 @@ def test_unified_attn_2d( ...@@ -1339,7 +1444,8 @@ def test_unified_attn_2d(
sinks = None sinks = None
if use_sinks: if use_sinks:
sinks = torch.randn(nheads, device=device, dtype=dtype) sink_dtype = torch.bfloat16 if use_fp8 else torch.float32
sinks = (torch.randn(nheads, device=device, dtype=sink_dtype) * 0.25) + 2.0
mm_prefix_range = None mm_prefix_range = None
if use_mm_prefix: if use_mm_prefix:
...@@ -1354,7 +1460,7 @@ def test_unified_attn_2d( ...@@ -1354,7 +1460,7 @@ def test_unified_attn_2d(
for i in range(batch_size): for i in range(batch_size):
ref_outs.append( ref_outs.append(
ref_attn( ref_attn(
q_list[i], k_list[i], v_list[i], q_ref_list[i], k_ref_list[i], v_ref_list[i],
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
...@@ -1369,7 +1475,19 @@ def test_unified_attn_2d( ...@@ -1369,7 +1475,19 @@ def test_unified_attn_2d(
ref_out = torch.cat(ref_outs, dim=0) ref_out = torch.cat(ref_outs, dim=0)
# ---- CUDA kernel ---- # ---- CUDA kernel ----
cuda_out, cuda_lse = varlen_fwd_unified( expected_hg_symbol = None
if seqlen_q >= 16:
if not (not use_fp8 and d == 256 and d_v == 256):
expected_hg_symbol = "hg_prefix_prefill_varlen_fwd"
elif use_fp8 or seqlen_q == 1 or 1 < seqlen_q < 16:
expected_hg_symbol = "hg_prefix_decode_varlen_fwd"
varlen_runner = (
varlen_fwd_unified
if expected_hg_symbol is None
else lambda *args, **kwargs: varlen_fwd_unified_expect_hg(expected_hg_symbol, *args, **kwargs)
)
cuda_out, cuda_lse = varlen_runner(
q_varlen, k_cache, v_cache, q_varlen, k_cache, v_cache,
cu_seqlens_q, seqused_k, block_table, cu_seqlens_q, seqused_k, block_table,
max_seqlen_q=seqlen_q, max_seqlen_q=seqlen_q,
...@@ -1384,6 +1502,9 @@ def test_unified_attn_2d( ...@@ -1384,6 +1502,9 @@ def test_unified_attn_2d(
s_aux=sinks, s_aux=sinks,
mm_prefix_range=mm_prefix_range, mm_prefix_range=mm_prefix_range,
return_softmax_lse=True, return_softmax_lse=True,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
) )
# # ---- Triton kernel ---- # # ---- Triton kernel ----
...@@ -1416,15 +1537,164 @@ def test_unified_attn_2d( ...@@ -1416,15 +1537,164 @@ def test_unified_attn_2d(
# triton_max_diff = (triton_out - ref_out).abs().max().item() # triton_max_diff = (triton_out - ref_out).abs().max().item()
print( print(
f"\n[{dtype} | causal={causal} | {mha_type} | bs={batch_size} " f"\n[{dtype} | fp8={use_fp8} | causal={causal} | {mha_type} | bs={batch_size} "
f"sq={seqlen_q} sk={seqlen_k} blk={block_size} | " f"sq={seqlen_q} sk={seqlen_k} d={d}/{d_v} blk={block_size} | "
f"alibi_sqrt={use_alibi_sqrt} qq_bias={use_qq_bias} " f"alibi_sqrt={use_alibi_sqrt} qq_bias={use_qq_bias} "
f"sinks={use_sinks} mm_prefix={use_mm_prefix}]" f"sinks={use_sinks} mm_prefix={use_mm_prefix}]"
f"\n CUDA max_diff={cuda_max_diff:.4e}" f"\n CUDA max_diff={cuda_max_diff:.4e}"
# f"\n Triton max_diff={triton_max_diff:.4e}" # f"\n Triton max_diff={triton_max_diff:.4e}"
) )
cal_diff(cuda_out, ref_out, "out") cutlass_b16_256_prefill = (not use_fp8 and seqlen_q >= 16 and d == 256 and d_v == 256)
cal_diff(
cuda_out,
ref_out,
"out",
use_fp8=use_fp8,
cos_threshold=1e-3 if cutlass_b16_256_prefill else None,
)
@pytest.mark.parametrize(
"batch_size,seqlen_q,seqlen_k,causal,window_size",
[
pytest.param(1, 256, 4096, True, (-1, -1), id="causal-large-sq"),
pytest.param(2, 257, 4103, True, (-1, -1), id="causal-unaligned-sq-sk"),
pytest.param(4, 512, 8193, True, (-1, -1), id="causal-large-bs-sq-unaligned-sk"),
# The torch reference materializes dense [heads, seq_q, seq_kv] scores,
# so keep 4K/8K correctness at bs=1 while still hitting long prefill.
pytest.param(1, 4096, 4096, True, (-1, -1), id="causal-sq4096"),
pytest.param(1, 8192, 8192, True, (-1, -1), id="causal-sq8192"),
pytest.param(1, 4097, 8193, True, (-1, -1), id="causal-unaligned-sq4097-sk8193"),
pytest.param(1, 17, 129, False, (511, 0), id="swa-lower-boundary-unaligned"),
pytest.param(3, 65, 1025, False, (511, 0), id="swa-mid-unaligned"),
pytest.param(2, 513, 1537, False, (511, 0), id="swa-large-unaligned-sq-sk"),
pytest.param(1, 4096, 8193, False, (511, 0), id="swa-sq4096-sk8193"),
pytest.param(1, 8192, 8193, False, (511, 0), id="swa-sq8192-sk8193"),
],
)
def test_unified_attn_fp8_192x128_prefill_corner_cases(
batch_size, seqlen_q, seqlen_k, causal, window_size
):
test_unified_attn_2d(
batch_size=batch_size,
seqlen_q=seqlen_q,
seqlen_k=seqlen_k,
block_size=UNIFIED_BLOCK_SIZE,
d=192,
d_v=128,
causal=causal,
window_size=window_size,
softcap=0.0,
mha_type="gqa",
dtype=torch.bfloat16,
use_alibi_sqrt=False,
use_qq_bias=False,
use_sinks=False,
use_mm_prefix=False,
use_fp8=True,
)
@pytest.mark.parametrize(
"batch_size,seqlen_q,seqlen_k,causal,window_size",
[
pytest.param(1, 17, 129, True, (-1, -1), id="causal-lower-boundary-unaligned"),
pytest.param(2, 65, 1025, True, (-1, -1), id="causal-mid-unaligned"),
pytest.param(1, 256, 4097, True, (-1, -1), id="causal-large-unaligned-sk"),
pytest.param(1, 17, 129, False, (511, 0), id="swa-lower-boundary-unaligned"),
pytest.param(2, 65, 1025, False, (511, 0), id="swa-mid-unaligned"),
pytest.param(1, 256, 4097, False, (511, 0), id="swa-large-unaligned-sk"),
],
)
def test_unified_attn_fp8_256x256_prefill_corner_cases(
batch_size, seqlen_q, seqlen_k, causal, window_size
):
test_unified_attn_2d(
batch_size=batch_size,
seqlen_q=seqlen_q,
seqlen_k=seqlen_k,
block_size=UNIFIED_BLOCK_SIZE,
d=256,
d_v=256,
causal=causal,
window_size=window_size,
softcap=0.0,
mha_type="gqa",
dtype=torch.bfloat16,
use_alibi_sqrt=False,
use_qq_bias=False,
use_sinks=False,
use_mm_prefix=False,
use_fp8=True,
)
@pytest.mark.parametrize(
"batch_size,seqlen_q,seqlen_k,causal,window_size",
[
pytest.param(1, 1, 4096, True, (-1, -1), id="decode-sq1-large-sk"),
pytest.param(2, 4, 2048, True, (-1, -1), id="decode-mtp4-large-sk"),
pytest.param(1, 17, 129, True, (-1, -1), id="prefill-causal-lower-boundary"),
pytest.param(2, 65, 1025, False, (511, 0), id="prefill-swa-unaligned"),
],
)
def test_unified_attn_fp8_192x128_sinks_corner_cases(
batch_size, seqlen_q, seqlen_k, causal, window_size
):
test_unified_attn_2d(
batch_size=batch_size,
seqlen_q=seqlen_q,
seqlen_k=seqlen_k,
block_size=UNIFIED_BLOCK_SIZE,
d=192,
d_v=128,
causal=causal,
window_size=window_size,
softcap=0.0,
mha_type="gqa",
dtype=torch.bfloat16,
use_alibi_sqrt=False,
use_qq_bias=False,
use_sinks=True,
use_mm_prefix=False,
use_fp8=True,
)
@pytest.mark.parametrize("block_size", [64, 128])
@pytest.mark.parametrize("d,d_v", [(128, 128), (192, 128)])
@pytest.mark.parametrize(
"batch_size,seqlen_q,seqlen_k,causal,window_size",
[
pytest.param(1, 64, 257, True, (-1, -1), id="prefill-causal-unaligned"),
pytest.param(2, 65, 1025, False, (511, 0), id="prefill-swa-unaligned"),
pytest.param(16, 1, 2048, True, (-1, -1), id="decode-sq1"),
pytest.param(16, 4, 2048, False, (-1, -1), id="decode-mtp4-noncausal"),
pytest.param(16, 4, 2048, False, (511, 0), id="decode-mtp4-swa"),
],
)
def test_unified_attn_b16_page64_regression(
batch_size, seqlen_q, seqlen_k, causal, window_size, d, d_v, block_size
):
test_unified_attn_2d(
batch_size=batch_size,
seqlen_q=seqlen_q,
seqlen_k=seqlen_k,
block_size=block_size,
d=d,
d_v=d_v,
causal=causal,
window_size=window_size,
softcap=0.0,
mha_type="gqa",
dtype=torch.bfloat16,
use_alibi_sqrt=False,
use_qq_bias=False,
use_sinks=False,
use_mm_prefix=False,
use_fp8=False,
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -1440,8 +1710,8 @@ def benchmark_unified_attention(): ...@@ -1440,8 +1710,8 @@ def benchmark_unified_attention():
torch.manual_seed(42) torch.manual_seed(42)
dtype = torch.float16 dtype = torch.float16
d = 256 d = 128
block_size = 320 block_size = UNIFIED_BLOCK_SIZE
warmup = 10 warmup = 10
repeat = 50 repeat = 50
...@@ -1452,30 +1722,35 @@ def benchmark_unified_attention(): ...@@ -1452,30 +1722,35 @@ def benchmark_unified_attention():
MAX_MM_RANGES = 2 MAX_MM_RANGES = 2
# GQA # GQA
nheads = 8 nheads = 24
nheads_k = 1 nheads_k = 2
# workload shapes # workload shapes
shapes = [ shapes = [
# (8, 2048, 2048), # (8, 2048, 2048),
# (4, 2048, 2048), # (4, 2048, 2048),
(4, 1, 2048), (1, 4, 51200),
(2, 4, 51200),
(4, 4, 51200),
(8, 4, 51200),
(16, 4, 51200),
(32, 4, 51200),
# (4, 2048, 4096), # (4, 2048, 4096),
] ]
# feature configs (C A Q S P) # feature configs (C A Q S P)
feature_configs = [ feature_configs = [
(1,0,0,0,0), (1,0,0,0,0),
(1,1,0,0,0), # (1,1,0,0,0),
(1,0,1,0,0), # (1,0,1,0,0),
(1,0,0,1,0), # (1,0,0,1,0),
(1,0,0,0,1), # (1,0,0,0,1),
(1,1,1,0,0), # (1,1,1,0,0),
(1,1,0,1,0), # (1,1,0,1,0),
(1,1,0,0,1), # (1,1,0,0,1),
(1,1,1,1,0), # (1,1,1,1,0),
(1,1,1,0,1), # (1,1,1,0,1),
(1,1,1,1,1), # (1,1,1,1,1),
] ]
print("\nUnified Attention GQA Benchmark") print("\nUnified Attention GQA Benchmark")
...@@ -1485,7 +1760,8 @@ def benchmark_unified_attention(): ...@@ -1485,7 +1760,8 @@ def benchmark_unified_attention():
f"{'BS':>3} {'SQ':>6} {'SK':>6} | " f"{'BS':>3} {'SQ':>6} {'SK':>6} | "
f"{'C':>1} {'A':>1} {'Q':>1} {'S':>1} {'P':>1} | " f"{'C':>1} {'A':>1} {'Q':>1} {'S':>1} {'P':>1} | "
f"{'CUDA(ms)':>10} {'Triton(ms)':>11} | " f"{'CUDA(ms)':>10} {'Triton(ms)':>11} | "
f"{'CUDA TFLOPS':>11} {'Triton TFLOPS':>13} | {'Speedup':>8}" f"{'CUDA TFLOPS':>11} {'Triton TFLOPS':>13} | "
f"{'CUDA(GB/s)':>10} {'Triton(GB/s)':>12} | {'Speedup':>8}"
) )
print("-" * 120) print("-" * 120)
...@@ -1548,6 +1824,15 @@ def benchmark_unified_attention(): ...@@ -1548,6 +1824,15 @@ def benchmark_unified_attention():
triton_out = torch.zeros_like(q_varlen) triton_out = torch.zeros_like(q_varlen)
total_bytes = estimate_unified_attention_bytes(
batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, block_size,
q_bytes=torch.finfo(dtype).bits // 8,
k_bytes=torch.finfo(dtype).bits // 8,
v_bytes=torch.finfo(dtype).bits // 8,
d_v=d,
window_size=window_size,
)
for C, A, Q, S, P in feature_configs: for C, A, Q, S, P in feature_configs:
causal = bool(C) causal = bool(C)
...@@ -1574,7 +1859,8 @@ def benchmark_unified_attention(): ...@@ -1574,7 +1859,8 @@ def benchmark_unified_attention():
sinks = None sinks = None
if use_sinks: if use_sinks:
sinks = torch.randn(nheads, device=device, dtype=dtype) sink_dtype = torch.bfloat16 if use_fp8 else torch.float32
sinks = torch.randn(nheads, device=device, dtype=sink_dtype)
mm_prefix_range = None mm_prefix_range = None
if use_mm_prefix: if use_mm_prefix:
...@@ -1655,6 +1941,9 @@ def benchmark_unified_attention(): ...@@ -1655,6 +1941,9 @@ def benchmark_unified_attention():
# FLOPs # FLOPs
flops = 4.0 * batch_size * nheads * seqlen_q * seqlen_k * d flops = 4.0 * batch_size * nheads * seqlen_q * seqlen_k * d
cuda_bandwidth = total_bytes / 1e9 / cuda_ms * 1000
triton_bandwidth = total_bytes / 1e9 / triton_ms * 1000
cuda_tflops = flops / cuda_ms / 1e9 cuda_tflops = flops / cuda_ms / 1e9
triton_tflops = flops / triton_ms / 1e9 triton_tflops = flops / triton_ms / 1e9
...@@ -1663,11 +1952,277 @@ def benchmark_unified_attention(): ...@@ -1663,11 +1952,277 @@ def benchmark_unified_attention():
f"{C} {A} {Q} {S} {P} | " f"{C} {A} {Q} {S} {P} | "
f"{cuda_ms:10.3f} {triton_ms:11.3f} | " f"{cuda_ms:10.3f} {triton_ms:11.3f} | "
f"{cuda_tflops:11.2f} {triton_tflops:13.2f} | " f"{cuda_tflops:11.2f} {triton_tflops:13.2f} | "
f"{cuda_bandwidth:10.2f} {triton_bandwidth:12.2f} | "
f"{triton_ms/cuda_ms:8.2f}x" f"{triton_ms/cuda_ms:8.2f}x"
) )
print("=" * 120) print("=" * 120)
def benchmark_hg_b16_fp8_pa():
device = torch.device("cuda")
torch.manual_seed(42)
dtype = torch.bfloat16
fp8_dtype = current_platform.fp8_dtype()
d = 192
d_v = 128
block_size = UNIFIED_BLOCK_SIZE
warmup = 10
repeat = 50
nheads = 16
nheads_k = 2
softcap = 0.0
shapes = [
(bs, seqlen_q, 51200)
for seqlen_q in (1, 4)
for bs in (1, 2, 4, 8, 16, 32, 64)
]
shapes += [
(bs, seqlen_q, 4096)
for seqlen_q in (32, 128, 256, 512)
for bs in (1, 2, 4)
]
shapes += [
(1, 257, 4103),
(2, 513, 8193),
(1, 4096, 4096),
(1, 4097, 8193),
(1, 8192, 8192),
(1, 8192, 8193),
]
windows = [(-1, -1), (511, 0)]
def time_fn(fn):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(repeat):
fn()
torch.cuda.synchronize()
return (time.perf_counter() - start) / repeat * 1000
def unwrap_out(result):
if isinstance(result, (tuple, list)):
return result[0]
return result
def diff_stats(x, y):
x = x.float()
y = y.float()
denom = torch.clamp((x * x + y * y).sum(), min=1e-12)
cos_diff = 1 - 2 * (x * y).sum() / denom
max_diff = (x - y).abs().max()
return cos_diff.item(), max_diff.item()
max_ref_scores = int(os.getenv("BENCH_REF_MAX_SCORES", "20000000"))
print("\nHG Unified PA BF16/FP8 Benchmark")
print("=" * 170)
print(
f"{'BS':>3} {'SQ':>3} {'SK':>6} {'D':>7} {'WINDOW':>12} | "
f"{'HG_OK':>5} {'HG(ms)':>8} {'TRI(ms)':>8} {'FP8(ms)':>8} | "
f"{'FP8/HG':>8} {'FP8/TRI':>8} {'HG/TRI':>8} | "
f"{'REF_cos':>9} {'REF_max':>9} | {'FP8 GB/s':>9} {'NOTE':>18}"
)
print("-" * 170)
summary = {
"total": 0,
"hg_ok": 0,
"hg_fail": 0,
"fp8_hg_speedups": [],
"fp8_triton_speedups": [],
}
for window_size in windows:
for batch_size, seqlen_q, seqlen_k in shapes:
summary["total"] += 1
causal = window_size == (-1, -1)
softmax_scale = d ** (-0.5)
q_list = [torch.randn(seqlen_q, nheads, d, device=device, dtype=dtype) for _ in range(batch_size)]
k_list = [torch.randn(seqlen_k, nheads_k, d, device=device, dtype=dtype) for _ in range(batch_size)]
v_list = [torch.randn(seqlen_k, nheads_k, d_v, device=device, dtype=dtype) for _ in range(batch_size)]
q_b16 = torch.cat(q_list, dim=0)
k_b16, v_b16, block_table = make_paged_kv(k_list, v_list, block_size, device, dtype)
q_fp8 = q_b16.to(fp8_dtype)
k_fp8 = k_b16.to(fp8_dtype)
v_fp8 = v_b16.to(fp8_dtype)
cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(
torch.tensor([seqlen_q] * batch_size, dtype=torch.int32), dim=0
)
seqused_k = torch.tensor([seqlen_k] * batch_size, device=device, dtype=torch.int32)
q_descale = torch.ones((batch_size, nheads), device=device, dtype=torch.float32)
k_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32)
v_descale = torch.ones((batch_size, nheads_k), device=device, dtype=torch.float32)
expected_hg_symbol = (
"hg_prefix_prefill_varlen_fwd"
if seqlen_q >= 16
else "hg_prefix_decode_varlen_fwd"
)
def run_b16_hg_checked():
return unwrap_out(varlen_fwd_unified_expect_hg(
expected_hg_symbol,
q_b16, k_b16, v_b16, cu_seqlens_q, seqused_k, block_table,
max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k,
softmax_scale=softmax_scale, causal=causal,
window_size=window_size, softcap=softcap,
return_softmax_lse=False,
))
def run_b16_hg():
return unwrap_out(varlen_fwd_unified(
q_b16, k_b16, v_b16, cu_seqlens_q, seqused_k, block_table,
max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k,
softmax_scale=softmax_scale, causal=causal,
window_size=window_size, softcap=softcap,
return_softmax_lse=False,
))
def run_fp8_checked():
return unwrap_out(varlen_fwd_unified_expect_hg(
expected_hg_symbol,
q_fp8, k_fp8, v_fp8, cu_seqlens_q, seqused_k, block_table,
max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k,
softmax_scale=softmax_scale, causal=causal,
window_size=window_size, softcap=softcap,
return_softmax_lse=False,
q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
))
def run_fp8():
return unwrap_out(varlen_fwd_unified(
q_fp8, k_fp8, v_fp8, cu_seqlens_q, seqused_k, block_table,
max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k,
softmax_scale=softmax_scale, causal=causal,
window_size=window_size, softcap=softcap,
return_softmax_lse=False,
q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
))
triton_out = torch.empty(
(q_b16.shape[0], nheads, d_v), device=device, dtype=dtype
)
def run_triton_b16():
unified_attention(
q=q_b16,
k=k_b16,
v=v_b16,
out=triton_out,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
block_table=block_table,
softcap=softcap,
q_descale=None,
k_descale=None,
v_descale=None,
seq_threshold_3D=128,
)
return triton_out
note = ""
hg_ms = float("nan")
fp8_over_hg = float("nan")
hg_over_triton = float("nan")
hg_cos = float("nan")
hg_max = float("nan")
triton_ms = time_fn(run_triton_b16)
run_fp8_checked()
torch.cuda.synchronize()
fp8_ms = time_fn(run_fp8)
try:
b16_hg_out = run_b16_hg_checked()
torch.cuda.synchronize()
num_ref_scores = batch_size * seqlen_q * seqlen_k * nheads
hg_correct = True
if num_ref_scores <= max_ref_scores:
ref_out = torch.cat([
ref_attn(
q_list[i], k_list[i], v_list[i],
causal=causal,
window_size=window_size,
softmax_scale=softmax_scale,
softcap=softcap,
)
for i in range(batch_size)
], dim=0)
hg_cos, hg_max = diff_stats(b16_hg_out, ref_out)
hg_correct = hg_cos < 1e-3
if not hg_correct:
note = "HG_BF16_REF_DIFF"
else:
note = "REF_SKIP"
if not hg_correct:
summary["hg_fail"] += 1
else:
hg_ms = time_fn(run_b16_hg)
fp8_over_hg = hg_ms / fp8_ms
hg_over_triton = triton_ms / hg_ms
summary["hg_ok"] += 1
summary["fp8_hg_speedups"].append(fp8_over_hg)
except Exception as exc:
summary["hg_fail"] += 1
note = type(exc).__name__
fp8_over_triton = triton_ms / fp8_ms
summary["fp8_triton_speedups"].append(fp8_over_triton)
fp8_bytes = estimate_unified_attention_bytes(
batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, block_size,
q_bytes=1, k_bytes=1, v_bytes=1, out_bytes=2, d_v=d_v,
window_size=window_size,
)
print(
f"{batch_size:3d} {seqlen_q:3d} {seqlen_k:6d} {str(d) + '/' + str(d_v):>7} {str(window_size):>12} | "
f"{str(math.isfinite(hg_ms)):>5} {hg_ms:8.3f} {triton_ms:8.3f} {fp8_ms:8.3f} | "
f"{fp8_over_hg:8.2f} {fp8_over_triton:8.2f} {hg_over_triton:8.2f} | "
f"{hg_cos:9.2e} {hg_max:9.2e} | "
f"{fp8_bytes / 1e9 / fp8_ms * 1000:9.2f} {note:>18}"
)
print("-" * 170)
if summary["fp8_hg_speedups"]:
hg_speedups = torch.tensor(summary["fp8_hg_speedups"], dtype=torch.float32)
print(
"HG_BF16 baseline summary: "
f"ok={summary['hg_ok']}/{summary['total']} "
f"fail={summary['hg_fail']} "
f"fp8/hg mean={hg_speedups.mean().item():.3f} "
f"median={hg_speedups.median().item():.3f} "
f"min={hg_speedups.min().item():.3f} "
f"max={hg_speedups.max().item():.3f}"
)
triton_speedups = torch.tensor(summary["fp8_triton_speedups"], dtype=torch.float32)
print(
"Triton BF16 reference summary: "
f"fp8/triton mean={triton_speedups.mean().item():.3f} "
f"median={triton_speedups.median().item():.3f} "
f"min={triton_speedups.min().item():.3f} "
f"max={triton_speedups.max().item():.3f}"
)
print("=" * 170)
if __name__ == "__main__": if __name__ == "__main__":
benchmark_unified_attention() if os.getenv("RUN_UNIFIED_BENCHMARK") == "1":
\ No newline at end of file benchmark_hg_b16_fp8_pa()
else:
test_unified_attn_2d(
batch_size=1, seqlen_q=1, seqlen_k=40960, block_size=UNIFIED_BLOCK_SIZE,
d=192, d_v=128, causal=False, window_size=(511, 0), softcap=0.0,
mha_type="gqa", dtype=torch.bfloat16,
use_alibi_sqrt=False, use_qq_bias=False, use_sinks=False, use_mm_prefix=False,
use_fp8=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment