Commit 2033d805 authored by zhanghj2's avatar zhanghj2
Browse files

支持纯bf16

parent 58b43d4a
...@@ -75,7 +75,7 @@ dense_attn_decode_interface( ...@@ -75,7 +75,7 @@ dense_attn_decode_interface(
const int num_heads = num_heads_k; const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 64), 1); int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 16), 1);
KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
namespace Config { namespace Config {
static constexpr int BLOCK_SIZE_M = 64; static constexpr int BLOCK_SIZE_M = 16;
static constexpr int PAGE_BLOCK_SIZE = 64; static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int HEAD_DIM_K = 576; static constexpr int HEAD_DIM_K = 576;
......
This diff is collapsed.
...@@ -7,13 +7,12 @@ ...@@ -7,13 +7,12 @@
#include "config.h" #include "config.h"
using TMABarrier = cutlass::arch::ClusterTransactionBarrier;
using namespace cute; using namespace cute;
template<typename InputT_> template<typename InputT_, bool Is_causal_>
struct Traits { struct Traits {
using InputT = InputT_; using InputT = InputT_;
static constexpr bool Is_causal = Is_causal_;
static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
...@@ -23,63 +22,105 @@ struct Traits { ...@@ -23,63 +22,105 @@ struct Traits {
static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>); static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>);
using TiledMMA_QK_sQ = decltype(make_tiled_mma( static constexpr int kBlockM = BLOCK_SIZE_M;
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(), static constexpr int kBlockN = PAGE_BLOCK_SIZE;
Layout<Shape<_1, _1, _1>>{} static constexpr int kHeadDim = HEAD_DIM_K;
)); static constexpr int kHeadDimV = HEAD_DIM_V;
static constexpr int kNWarps = 4;
using TiledMMA_QK_rQ = decltype(make_tiled_mma(
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<_1, _1, _1>>{}
));
using SmemLayoutQ = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_K>>{}
));
using Element = InputT;
using elem_type = Element;
using ElementAccum = float;
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
using SmemLayoutAtomK = decltype(composition(
Swizzle<3, 3, 3>{},
Layout<Shape<Int<8>, Int<32>>, Stride<Int<32>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape( using SmemLayoutK = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{}, SmemLayoutAtomK{},
Shape<Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>{} Shape<Int<kBlockN>, Int<16 * 32>>{}));
));
using SmemLayoutK_place_holder = decltype(tile_to_shape(
using SmemLayoutV = decltype(composition( SmemLayoutAtomK{},
SmemLayoutK{}, Shape<Int<kBlockN>, Int<15 * 32>>{}));
make_layout(Shape<Int<HEAD_DIM_V>, Int<PAGE_BLOCK_SIZE>>{}, GenRowMajor{}) using SmemLayoutAtomV = SmemLayoutAtomK;
)); // A transposed version of SmemLayoutK using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<4*16*16>>{}));
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomQ = decltype(composition(
Swizzle<3, 3, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
// #if defined(__gfx936__) || defined(__gfx938__)
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
// #elif defined(__gfx928__)
// using MMA_Atom_Arch = std::conditional_t<
// std::is_same_v<elem_type, cutlass::half_t>,
// MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
// MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
// >;
// using TiledMma = TiledMMA<
// MMA_Atom_Arch,
// Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
// ValLayoutMNK>;
// #endif
using MMA_Atom_Arch_16x32 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x32x16_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x32x16_F32BF16BF16F32_NT>
>;
using TiledMma_O = TiledMMA<
MMA_Atom_Arch_16x32,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using GmemLayoutAtomQ = Layout<Shape <_32, _8>,
Stride< _8, _1>>;
using GmemTiledCopyQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{}));
using SmemLayoutP0 = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
));
using rP0Layout = decltype(layout(partition_fragment_C(
TiledMMA_QK_sQ{},
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
)));
struct SharedMemoryPlan { struct SharedMemoryPlan {
cute::array_aligned<InputT, cosize_v<SmemLayoutQ>> smem_sQ; union {
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK0; struct {
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK1; cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v; // Double buffer
cute::array_aligned<InputT, cosize_v<SmemLayoutP0>> smem_sP0;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sM; };
cute::array_aligned<float, 2*BLOCK_SIZE_M> sL_reduction_wksp; struct {
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale0; cute::array_aligned<Element, cute::cosize_v<SmemLayoutK_place_holder>> smem_temp; // Double buffer
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale1; cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
TMABarrier barriers_K0[HEAD_DIM_K/64]; cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
TMABarrier barriers_K1[HEAD_DIM_K/64]; cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
TMABarrier barrier_Q; };
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
};
};
}; };
}; };
......
...@@ -88,6 +88,18 @@ struct RingBufferState { ...@@ -88,6 +88,18 @@ struct RingBufferState {
} }
}; };
// BOOL_SWITCH(COND, CONST_NAME, callable)
//
// Turns the runtime boolean COND into a compile-time constant. The trailing
// argument (normally a lambda) is invoked from inside an immediately-invoked
// outer lambda in which CONST_NAME is declared as a `static constexpr bool`
// equal to the branch that was taken, so the callable can use CONST_NAME as a
// template argument or in `if constexpr`. The callable's result is returned
// to the call site; both instantiations must therefore return the same type.
#define BOOL_SWITCH(COND, CONST_NAME, ...)             \
    [&] {                                              \
        if (COND) {                                    \
            static constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();                      \
        }                                              \
        static constexpr bool CONST_NAME = false;      \
        return __VA_ARGS__();                          \
    }()
namespace flash { namespace flash {
using namespace cute; using namespace cute;
...@@ -559,5 +571,170 @@ lds_direct_copy_for_prefill_sparse_mla( ...@@ -559,5 +571,170 @@ lds_direct_copy_for_prefill_sparse_mla(
} }
// Issues one 16-byte (dwordx4) buffer load per thread from the global tensor
// `src` into the register value `dst`.
//
// Template parameters:
//   Is_even_MN - when false, threads whose row is >= max_MN get their byte
//                offset forced to -1 instead of a real address.
//                NOTE(review): presumably this relies on the buffer range
//                check to return zeros for such lanes - confirm against the ISA.
//   Is_even_K  - currently unused; kept so existing instantiations compile.
//   mma_layout - selects the thread -> (row, column) mapping:
//                  true : row = tidx % 16 + warp_id * 16, column group = lane / 16
//                  false: row = tidx / 4,                 column group = lane % 4
//   use_asm    - true emits the load through inline assembly, false through
//                __builtin_amdgcn_buffer_load_dwordx4.
//
// Parameters:
//   src        - tensor whose data pointer becomes the base of the buffer descriptor.
//   dst        - receives the 128 bits that were loaded.
//   k_idx_     - index of the 32-element column window to read; made
//                wave-uniform with readfirstlane.
//   row_stride - distance between consecutive rows, in elements.
//   offset_k   - currently unused; kept for interface compatibility.
//   max_MN     - row bound, only consulted when Is_even_MN is false.
//
// The inline-assembly path contains no wait instruction in this function, so
// NOTE(review): the caller presumably has to wait for the load before reading
// `dst` - confirm at the call sites.
template <
    bool Is_even_MN=true,
    bool Is_even_K=true,
    bool mma_layout = false,
    bool use_asm = false,
    class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy(
    Tensor<SrcEngine, SrcLayout> const& src,
    uint128_t & dst,
    int k_idx_, const int row_stride,
    int offset_k,
    const int max_MN=0)
{
    constexpr int warp_size = 64;
    constexpr int element_size = 2;          // bytes per element
    constexpr int elements_per_thread = 8;   // 8 elements * 2 bytes = one dwordx4

    const int tidx = threadIdx.x;
    const int lane = tidx % warp_size;
    const int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);

    // Words 0/1 of the descriptor hold the 64-bit base address; the two
    // layouts differ only in how they fill them and in the thread mapping.
    uint32x4_t global_addr = {0};
    int row_offset = 0;
    int col = 0;
    if constexpr (mma_layout)
    {
        const int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);

        struct PtrWrapper {
            uint32_t former;
            uint32_t latter;
        };
        PtrWrapper glob_ptr;
        *reinterpret_cast<uint64_t*>(&glob_ptr) = reinterpret_cast<uint64_t>(src.data().get());
        // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
        global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
        global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);

        row_offset = tidx % 16 + warp_id * 16;
        col = lane / 16;
    }
    else
    {
        *reinterpret_cast<uint64_t*>(&global_addr) = reinterpret_cast<uint64_t>(src.data().get());
        // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride

        row_offset = tidx / 4;
        col = lane % 4;
    }
    // NOTE(review): words 2/3 look like the record count and the format/config
    // word of the buffer descriptor - confirm the exact bit meaning for the target.
    global_addr[2] = 0x80000000;
    global_addr[3] = 0x00020000;

    const int col_offset = col * elements_per_thread + k_idx * 32;
    int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
    if constexpr (!Is_even_MN) {
        if (row_offset >= max_MN) offset_v = -1;
    }

    if constexpr(use_asm) {
        asm volatile(
            "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
            " \n\t" :"=v"(dst),
            "+v"(offset_v), "+s"(global_addr)
        );
    }
    else {
        auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
        dst = *reinterpret_cast<uint128_t*>(&res);
    }
}
// Writes the 128-bit value `src` into `dst` with a single uint128_t store
// that starts at the address of element (0, 0, k_idx).
// NOTE(review): assumes the elements covered by those 16 bytes are contiguous
// and suitably aligned in `dst` - confirm against the fragment layout used by
// the callers.
template<
    class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
buffer_to_tensor(const uint128_t & src, Tensor<SrcEngine, SrcLayout> & dst, int k_idx)
{
    auto* slot = reinterpret_cast<uint128_t*>(&dst(0, 0, k_idx));
    *slot = src;
}
// Re-distributes per-thread values from `tOrP` into an A-operand fragment of
// `tiled_mma_o` by bouncing them through the staging tensor `sAcc`.
//
//   1. Each thread scatters tOrP(0..3, 0, 0) into `sAcc` at linear indices
//      derived from its lane (threadIdx.x % 64) and wave (threadIdx.x / 64).
//   2. __syncthreads() makes the scattered values visible to the whole block.
//   3. Each thread gathers 16 values (4 values for each of 4 k-slices) back
//      from `sAcc` into a fragment created with partition_fragment_A over a
//      16x64 row-major view of `sAcc`.
//
// `tiled_mma` is not used by the body; it is kept so callers stay unchanged.
// NOTE(review): `sAcc` is presumably a shared-memory buffer that the caller
// must not overwrite until every thread has finished step 3 - confirm there.
template <class TiledMma, class TiledMma_O,
          typename Engine0, typename Layout0,
          typename Engine1, typename Layout1
>
__forceinline__ __device__ auto convert_layout_acc_Aregs_dense(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor<Engine0, Layout0> const& tOrP,
                                                               Tensor<Engine1, Layout1> const& sAcc)
{
    const int lane = threadIdx.x % 64;
    const int wave = threadIdx.x / 64;

    // Scatter: the four source values land 16 * 8 elements apart.
    const int scatter_base = (lane % 16) * 8 + (lane / 16) + (wave % 2) * 4 + (wave / 2) * 16 * 32;
    #pragma unroll
    for (int v = 0; v < 4; ++v) {
        sAcc(scatter_base + v * 16 * 8) = tOrP(v, 0, 0);
    }
    __syncthreads();

    // View the staging buffer as a 16x64 row-major tile to shape the fragment.
    using StagingAtom = Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;
    using StagingLayout = decltype(tile_to_shape(
        StagingAtom{},
        Shape<Int<16>, Int<64>>{}));
    Tensor staged = make_tensor(sAcc.data(), StagingLayout{});
    auto thr_mma_o = tiled_mma_o.get_thread_slice(lane);
    Tensor frag_a = thr_mma_o.partition_fragment_A(staged);

    // Gather: k-slices 0/1 read the first 16*32 elements (second slice shifted
    // by 4), k-slices 2/3 read the same pattern offset by 16*32.
    const int gather_base = lane * 8;
    #pragma unroll
    for (int k = 0; k < 4; ++k) {
        const int slice_offset = (k % 2) * 4 + (k / 2) * 16 * 32;
        #pragma unroll
        for (int v = 0; v < 4; ++v) {
            frag_a(v, 0, k) = sAcc(gather_base + v + slice_offset);
        }
    }
    return frag_a;
}
} }
\ No newline at end of file
...@@ -223,9 +223,10 @@ def main(torch_dtype): ...@@ -223,9 +223,10 @@ def main(torch_dtype):
] ]
performance_cases = [ performance_cases = [
TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, test_performance=True) TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, h_q = h_q, test_performance=True)
for is_causal in [False, True] for is_causal in [False, True]
for s_q in [1, 2] for s_q in [1, 2]
for h_q in [16, 128]
for s_k in [4096, 8192, 16384, 32768] for s_k in [4096, 8192, 16384, 32768]
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment