Commit 0e1300f7 authored by zhanghj2's avatar zhanghj2
Browse files

适配v32的decode kernel

parent 7abe5160
......@@ -4,7 +4,7 @@
// #include <cutlass/arch/barrier.h>
using bf16 = cutlass::bfloat16_t;
using fp8 = cutlass::float_e4m3_t;
using fp8 = unsigned char;
// using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
// using cutlass::arch::fence_view_async_shared;
// using cutlass::arch::fence_barrier_init;
......
......@@ -16,10 +16,11 @@ template<ModelType MODEL_TYPE, int NUM_HEADS>
class KernelTemplate {
public:
static_assert(NUM_HEADS == 64 || NUM_HEADS == 128);
static constexpr int NUM_M_BLOCKS = NUM_HEADS / 64;
static constexpr int CLUSTER_SIZE = NUM_M_BLOCKS;
static_assert(NUM_HEADS == 64 || NUM_HEADS == 128 || NUM_HEADS == 16);
// todo only support tp8
static constexpr int BLOCK_M = 16;
static constexpr int NUM_M_BLOCKS = NUM_HEADS / BLOCK_M;
static constexpr bool Is_causal = false;
static constexpr int HEAD_DIM_K = MODEL_TYPE == ModelType::V32 ? 576 : 512;
static constexpr int HEAD_DIM_V = 512;
static constexpr int HEAD_DIM_ROPE = 64;
......@@ -28,67 +29,88 @@ static constexpr int HEAD_DIM_NOPE = HEAD_DIM_K - HEAD_DIM_ROPE;
static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
static constexpr int NUM_SCALES = MODEL_TYPE == ModelType::V32 ? 4 : 8; // For MODEL1: 7 fp8_e4m3 + 1 padding
static constexpr int NUM_THREADS = 128*3;
static constexpr int BLOCK_M = 64;
static constexpr int NUM_THREADS = 256;
static constexpr int TOPK_BLOCK_SIZE = 64;
static constexpr int NUM_K_BUFS = 2;
using SmemLayoutQTile = decltype(tile_to_shape(
GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{},
Shape<Int<BLOCK_M>, Int<64>>{}
));
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(tile_to_shape(
SmemLayoutQTile{},
Shape<Int<BLOCK_M>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
));
using SmemLayoutQ = SmemLayoutQTiles<HEAD_DIM_K/64>;
using SmemLayoutKTile = decltype(tile_to_shape(
GMMA::Layout_INTER_Atom<bf16, GMMA::Major::K>{},
Shape<Int<TOPK_BLOCK_SIZE>, _64>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(tile_to_shape(
SmemLayoutKTile{},
Shape<Int<TOPK_BLOCK_SIZE>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed = decltype(composition(
SmemLayoutKTiles<NUM_TILES>{},
Layout<Shape<Int<64*NUM_TILES>, Int<TOPK_BLOCK_SIZE>>, Stride<Int<TOPK_BLOCK_SIZE>, _1>>{}
));
static constexpr int OBUF_SW = 64;
using SmemLayoutOBufAtom = GMMA::Layout_K_SW128_Atom<bf16>;
using SmemLayoutOBuf = decltype(tile_to_shape(
SmemLayoutOBufAtom{},
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
Step<_1, _2>{}
));
using SmemLayoutOAccumBuf = Layout<
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
using elem_type = cutlass::bfloat16_t;
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x16x64_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x16x64_F32BF16BF16F32_NT>
>;
using SmemLayoutK = SmemLayoutKTiles<HEAD_DIM_K/64>;
using SmemLayoutV = SmemLayoutKTilesTransposed<HEAD_DIM_V/64>;
using SmemLayoutHalfV = SmemLayoutKTilesTransposed<HEAD_DIM_V/64/2>;
using SmemLayoutS = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}
));
static constexpr int kNWarps = 4;
using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using MMA_Atom_Arch_16_16_32 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x16x32_F32F16F16F32_NN>,
MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NN>
>;
using TiledMma_16_16_32 = TiledMMA<
MMA_Atom_Arch_16_16_32,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using MMA_Atom_Arch_16x32_NT = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x32x16_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x32x16_F32BF16BF16F32_NT>
>;
using TiledMma_O = TiledMMA<
MMA_Atom_Arch_16x32_NT,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using SmemLayoutAtomK = decltype(composition(
Swizzle<3, 3, 3>{},
Layout<Shape<Int<8>, Int<32>>, Stride<Int<32>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<TOPK_BLOCK_SIZE>, Int<8 * 32>>{}));
using SmemLayoutAtomV = SmemLayoutAtomK;
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<TOPK_BLOCK_SIZE>, Int<512>>{}));
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<512>, Int<TOPK_BLOCK_SIZE>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<4*16*16>>{}));
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
using Element = cutlass::bfloat16_t;
using ElementAccum = float;
struct SharedMemoryPlan {
union {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
struct {
// cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutV_tmp>> smem_v_tmp; // Double buffer
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
};
// struct {
// cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
// // cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_sum;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_max;
// };
// struct {
// cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
// };
};
// array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
// union {
// array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
......@@ -131,9 +153,8 @@ struct SharedMemoryPlan {
static __device__ __forceinline__ void
compute_attn_1rowblock_splitkv_sparse_mla_fp8(const SparseAttnDecodeParams &params, const DecodingSchedMeta& sched_meta, int batch_idx);
static __device__ __forceinline__ void
devfunc(const SparseAttnDecodeParams &params);
......
This diff is collapsed.
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cstdint>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cute/tensor.hpp>
#include "defines.h"
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
......@@ -80,3 +87,265 @@ struct RingBufferState {
return new_state;
}
};
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 64 || THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4 || THREADS == 2);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor(x, OFFSET, 64));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<1> {
// static_assert(THREADS == 64 || THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4 || THREADS == 2);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<32> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor(x, 16, 64));
return x;
}
};
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0, int begin_k=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, _));
}
}
}
template<int row, int col, int r_row, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col_rrow(Tensor0& src, Tensor1& dst)
{
auto lds = reinterpret_cast<__fp16 *>(src.data().get());
auto layout = src.layout();
constexpr short offset = layout(0, row, col) * 2;
auto d = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset);
uint16_t * d_ptr = reinterpret_cast<uint16_t*>(&d);
uint16_t * dst_ptr = reinterpret_cast<uint16_t*>(&(dst(0, r_row, col)));
dst_ptr[0] = d_ptr[0];
dst_ptr[1] = d_ptr[1];
dst_ptr[2] = d_ptr[2];
dst_ptr[3] = d_ptr[3];
dst_ptr[4] = d_ptr[4];
dst_ptr[5] = d_ptr[5];
dst_ptr[6] = d_ptr[6];
dst_ptr[7] = d_ptr[7];
}
template<int row, int col, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col(Tensor0& src, Tensor1& dst)
{
auto lds = reinterpret_cast<__fp16 *>(src.data().get());
auto layout = src.layout();
constexpr short offset = layout(0, row, col) * 2;
auto d = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset);
uint16_t * d_ptr = reinterpret_cast<uint16_t*>(&d);
uint16_t * dst_ptr = reinterpret_cast<uint16_t*>(&(dst(0, row, col)));
dst_ptr[0] = d_ptr[0];
dst_ptr[1] = d_ptr[1];
dst_ptr[2] = d_ptr[2];
dst_ptr[3] = d_ptr[3];
dst_ptr[4] = d_ptr[4];
dst_ptr[5] = d_ptr[5];
dst_ptr[6] = d_ptr[6];
dst_ptr[7] = d_ptr[7];
}
inline __device__ float fp8e4m3_to_fp32(const fp8& input) {
const uint32_t w = (uint32_t)input << 24;
const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
uint32_t renorm_shift = __clz(nonsign);
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23));
union {
uint32_t as_bits;
float as_value;
} fp32 = {result};
return fp32.as_value;
}
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
// static_assert(decltype(size<0>(acc_layout))::value == 4 || decltype(size<0>(acc_layout))::value == 8);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_1>{}); // (_4,_1,_2):(_1,_0,_4) -> ((_1,_4),_1,_2):((_0,_1),_0,_4)
return make_layout(make_layout(get<1>(l)), make_layout(get<1>(get<0>(l)), get<2>(l))); // (1, (4, 2)):((_0),(_1,_4))
};
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
if constexpr (std::is_same_v<To_type, From_type>)
{
return tensor;
}
constexpr int numel = decltype(size(tensor))::value;
Tensor tensor_To_type = make_tensor<To_type>(layout(tensor));
cutlass::Array<To_type, numel> *result_ptr = reinterpret_cast<cutlass::Array<To_type, numel> *>(tensor_To_type.data());
#if defined(__gfx938__)
{
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_to_nearest> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
else if constexpr (std::is_same_v<To_type, cutlass::float_e4m3_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel,cutlass::FloatRoundStyle::round_to_nearest> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
else {
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
return tensor_To_type;
}
#else
{
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_toward_zero> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
} else {
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
return tensor_To_type;
}
#endif
// cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// // HACK: this requires tensor to be "contiguous"
// auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
// return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template <class TiledMma, class TiledMma_O,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ auto convert_layout_acc_Aregs(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor<Engine0, Layout0> const& tOrP,
Tensor<Engine1, Layout1> const& sAcc)
{
using Value_type = typename Engine0::value_type;
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
sAcc((tid % 16 ) * 8 + (tid / 16) + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(0, 0, 0);
sAcc((tid % 16 ) * 8 + (tid / 16) + 1 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(1, 0, 0);
sAcc((tid % 16 ) * 8 + (tid / 16) + 2 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(2, 0, 0);
sAcc((tid % 16 ) * 8 + (tid / 16) + 3 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(3, 0, 0);
__syncthreads();
using SmemLayoutAtomP = Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<16>, Int<64>>{}));
Tensor sP_tmp = make_tensor(sAcc.data(),SmemLayoutP{});
auto thr_mma = tiled_mma_o.get_thread_slice(tid);
Tensor tSrACC = thr_mma.partition_fragment_A(sP_tmp);
tSrACC(0, 0, 0) = sAcc(tid * 8 + 0);
tSrACC(1, 0, 0) = sAcc(tid * 8 + 1);
tSrACC(2, 0, 0) = sAcc(tid * 8 + 2);
tSrACC(3, 0, 0) = sAcc(tid * 8 + 3);
tSrACC(0, 0, 1) = sAcc(tid * 8 + 0 + 4);
tSrACC(1, 0, 1) = sAcc(tid * 8 + 1 + 4);
tSrACC(2, 0, 1) = sAcc(tid * 8 + 2 + 4);
tSrACC(3, 0, 1) = sAcc(tid * 8 + 3 + 4);
tSrACC(0, 0, 2) = sAcc(tid * 8 + 0 + 16*32);
tSrACC(1, 0, 2) = sAcc(tid * 8 + 1 + 16*32);
tSrACC(2, 0, 2) = sAcc(tid * 8 + 2 + 16*32);
tSrACC(3, 0, 2) = sAcc(tid * 8 + 3 + 16*32);
tSrACC(0, 0, 3) = sAcc(tid * 8 + 0 + 4 + 16*32);
tSrACC(1, 0, 3) = sAcc(tid * 8 + 1 + 4 + 16*32);
tSrACC(2, 0, 3) = sAcc(tid * 8 + 2 + 4 + 16*32);
tSrACC(3, 0, 3) = sAcc(tid * 8 + 3 + 4 + 16*32);
return tSrACC;
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment