Commit 34e67b1e authored by zhangshao's avatar zhangshao
Browse files

first commit

parents
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/tensor.hpp>
#include "utils.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
Tensor<Engine3, Layout3> const &identity_MN,
const int max_MN, const int min_MN,
const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
Tensor rCos = make_fragment_like(Cos);
Tensor rSin = make_fragment_like(Sin);
Tensor rS = make_fragment_like(S);
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
cute::copy(S(_, m, k), rS(_, m, k));
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
cute::copy(Cos(_, m, k), rCos(_, m, k));
cute::copy(Sin(_, m, k), rSin(_, m, k));
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
#pragma unroll
for (int i = 0; i < size<0>(rS) / 2; ++i) {
float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
S_fp32(2 * i) = real;
S_fp32(2 * i + 1) = imag;
}
// Idk but I need to copy for the convert_type to work
Tensor S_fp32_copy = make_fragment_like(S_fp32);
cute::copy(S_fp32, S_fp32_copy);
using T = typename Engine0::value_type;
Tensor S_og_type = convert_type<T>(S_fp32_copy);
cute::copy(S_og_type, rS(_, m, k));
}
cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
Tensor<Engine3, Layout3> const &identity_MN,
const int max_MN, const int min_MN,
const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
Tensor rCos = make_fragment_like(Cos);
Tensor rSin = make_fragment_like(Sin);
Tensor rS = make_fragment_like(S);
Tensor rS_other = make_fragment_like(rS(_, 0, 0));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
cute::copy(S(_, m, k), rS(_, m, k));
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
cute::copy(gS_other, rS_other);
// if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
cute::copy(gCos, rCos(_, m, k));
cute::copy(gSin, rSin(_, m, k));
// if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
Tensor S_other_fp32 = convert_type<float>(rS_other);
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
#pragma unroll
for (int i = 0; i < size<0>(rS); ++i) {
S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
}
// Idk but I need to copy for the convert_type to work
Tensor S_fp32_copy = make_fragment_like(S_fp32);
cute::copy(S_fp32, S_fp32_copy);
using T = typename Engine0::value_type;
Tensor S_og_type = convert_type<T>(S_fp32_copy);
cute::copy(S_og_type, rS(_, m, k));
// if (cute::thread0()) { print_tensor(rS(_, m, k)); }
}
cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "philox.cuh"
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
// float ori = summary(mi);
summary(mi) = op(summary(mi), tensor(mi, ni));
// wangaq debug
// if (thread0()) {
// printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi));
// }
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<64>::run(src(i), op);
// if (blockIdx.x == 0) {
// printf("tid:%3d A:%7.4f B:%7.4f \n", threadIdx.x,
// src(i), dst(i));
// }
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_sum_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
int tidx = threadIdx.x % 64;
float a, b = 1.0;
#pragma unroll
for (int i = 0; i < size(dst); i++){
v4f d = {0};
a = src(i);
d = __builtin_amdgcn_mmac_f32_16x16x4f32(a, b, d);
dst(i) = d.x;
// if (blockIdx.x == 0) {
// printf("tid:%3d A:%7.4f B:%7.4f "
// "D:%10.4f %10.4f %10.4f %10.4f sum:%7.4f\n", threadIdx.x,
// a, b,
// d[0], d[1], d[2], d[3], dst(i));
// }
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_with_mmac_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
int tidx = threadIdx.x % 64;
float a, b = 0.0 + (tidx / 4 == 0 && tidx / 16 == 0) + (tidx / 4 == 5 && tidx / 16 == 1) +
(tidx / 4 == 10 && tidx / 16 == 2) + (tidx / 4 == 15 && tidx / 16 == 3);
#pragma unroll
for (int i = 0; i < size(dst); i++){
v4f d = {0};
a = src(i) == -INFINITY ? -10000.0 : src(i);
d = __builtin_amdgcn_mmac_f32_16x16x4f32(a, b, d);
dst(i) = isnan(d.x) ? -INFINITY : d.x;
dst(i) = op(dst(i), isnan(d.y) ? -INFINITY : d.y);
dst(i) = op(dst(i), isnan(d.z) ? -INFINITY : d.z);
dst(i) = op(dst(i), isnan(d.w) ? -INFINITY : d.w);
// if (blockIdx.x == 0) {
// printf("tid:%3d A:%7.4f B:%7.4f "
// "D:%10.4f %10.4f %10.4f %10.4f max:%7.4f\n", threadIdx.x,
// a, b,
// d[0], d[1], d[2], d[3], dst(i));
// }
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
#if 1
quad_allreduce_(summary, summary, op);
#else
quad_allreduce_with_mmac_(summary, summary, op);
#endif
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// The following macro will disable the use of fma.
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
// This macro is set in PyTorch and not FlashAttention
tensor(mi, ni) = custom_exp2f(tensor(mi, ni) * scale - max_scaled);
}
}
}
// Apply the exp to all the elements.
// template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
// __forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
// static_assert(Layout0::rank == 2, "Only support 2D Tensor");
// static_assert(Layout1::rank == 1, "Only support 1D Tensor");
// CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
// #pragma unroll
// for (int mi = 0; mi < size<0>(tensor); ++mi) {
// MaxOp<float> max_op;
// max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
// #pragma unroll
// for (int ni = 1; ni < size<1>(tensor); ni++) {
// max(mi) = max_op(max(mi), tensor(mi, ni));
// }
// max(mi) = Allreduce<4>::run(max(mi), max_op);
// // If max is -inf, then all elements must have been -inf (possibly due to masking).
// // We don't want (-inf - (-inf)) since that would give NaN.
// const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
// sum(mi) = 0;
// #pragma unroll
// for (int ni = 0; ni < size<1>(tensor); ++ni) {
// // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// // max * log_2(e)) This allows the compiler to use the ffma
// // instruction instead of fadd and fmul separately.
// tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
// sum(mi) += tensor(mi, ni);
// }
// SumOp<float> sum_op;
// sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
// }
// }
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
float skip_softmax_threshold;
uint32_t total_blocks;
uint32_t skipped_blocks;
__forceinline__ __device__ Softmax() : skip_softmax_threshold(0.f), total_blocks(0), skipped_blocks(0) {};
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// wangaq debug
// __syncthreads();
// if (thread0()) {
// printf("scores %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f "
// "%7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f "
// "%7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f "
// "%7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f\n",
// scores(0, 0), scores(0, 1), scores(0, 2), scores(0, 3),
// scores(0, 4), scores(0, 5), scores(0, 6), scores(0, 7),
// scores(0, 8), scores(0, 9), scores(0, 10), scores(0, 11),
// scores(0, 12), scores(0, 13), scores(0, 14), scores(0, 15),
// scores(1, 0), scores(1, 1), scores(1, 2), scores(1, 3),
// scores(1, 4), scores(1, 5), scores(1, 6), scores(1, 7),
// scores(1, 8), scores(1, 9), scores(1, 10), scores(1, 11),
// scores(1, 12), scores(1, 13), scores(1, 14), scores(1, 15)
// );
// }
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = custom_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ bool softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2, uint32_t * skip_softmax_vote) {
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
total_blocks++;
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
return false;
} else {
total_blocks++;
bool skip = true;
float scores_scale[kNRows];
Tensor scores_max_prev = make_fragment_like(row_max);
Tensor scores_max_local = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/true>(scores, scores_max_local);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
MaxOp<float> max_op;
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
skip &= (custom_exp2f((scores_max_local(mi) - scores_max_prev(mi)) * softmax_scale_log2) < skip_softmax_threshold);
// wangaq debug
// if (blockIdx.x == 0) {
// float skip_max = custom_exp2f((scores_max_local(mi) - scores_max_prev(mi)) * softmax_scale_log2);
// printf("tid:%d mi:%d total_blocks:%d scores_max_local:%10.4f scores_max_prev:%10.4f "
// "skip_max:%10.4f skip_softmax_threshold:%10.4f skip:%d "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f\n",
// threadIdx.x, mi, total_blocks, scores_max_local(mi), scores_max_prev(mi),
// skip_max, skip_softmax_threshold, skip,
// scores(mi, 0), scores(mi, 1), scores(mi, 2), scores(mi, 3),
// scores(mi, 4), scores(mi, 5), scores(mi, 6), scores(mi, 7),
// scores(mi, 8), scores(mi, 9), scores(mi, 10), scores(mi, 11),
// scores(mi, 12), scores(mi, 13), scores(mi, 14), scores(mi, 15)
// );
// }
scores_max_local(mi) = max_op(scores_max_local(mi), scores_max_prev(mi));
}
skip = __all_sync((uint64_t)0xffffffffffffffff, skip);
if (threadIdx.x % 64 == 0) {
// The leader of each warp votes.
atomicAnd(skip_softmax_vote, uint32_t(skip));
}
// __syncthreads();
s_barrier();
// asm volatile("s_waitcnt lgkmcnt(0); s_barrier\n");
// skip = *((uint32_t volatile*) skip_softmax_vote);
uint32_t skip_vote;
int skip_softmax_vote_addr = reinterpret_cast<size_t>(skip_softmax_vote);
asm volatile("ds_read_b32 %0, %1 offset:0\n" : "=v"(skip_vote) : "v"(skip_softmax_vote_addr) :);
asm volatile("s_waitcnt lgkmcnt(0); s_barrier\n");
if (skip_vote)
{
skipped_blocks++;
// wangaq debug
// if (blockIdx.x == 0) {
// printf("tid:%d total_blocks:%d skipped_blocks:%d\n",
// threadIdx.x, total_blocks, skipped_blocks
// );
// }
return true;
}
cute::copy(scores_max_local, row_max);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = custom_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
return false;
}
};
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, Tensor2 &acc_o_tail, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// wangaq debug
// __syncthreads();
// if (thread0()) {
// printf("scores %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f "
// "%7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f "
// "%7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f "
// "%7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f\n",
// scores(0, 0), scores(0, 1), scores(0, 2), scores(0, 3),
// scores(0, 4), scores(0, 5), scores(0, 6), scores(0, 7),
// scores(0, 8), scores(0, 9), scores(0, 10), scores(0, 11),
// scores(0, 12), scores(0, 13), scores(0, 14), scores(0, 15),
// scores(1, 0), scores(1, 1), scores(1, 2), scores(1, 3),
// scores(1, 4), scores(1, 5), scores(1, 6), scores(1, 7),
// scores(1, 8), scores(1, 9), scores(1, 10), scores(1, 11),
// scores(1, 12), scores(1, 13), scores(1, 14), scores(1, 15)
// );
// }
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
Tensor acc_o_tail_rowcol = make_tensor(acc_o_tail.data(), flash::convert_layout_acc_rowcol(acc_o_tail.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = custom_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_tail_rowcol); ++ni) { acc_o_tail_rowcol(mi, ni) *= scores_scale; }
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
template<bool Is_first, bool Check_inf=false,
typename Tensor0,
typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4>
__forceinline__ __device__ void softmax_rescale_o(
Tensor0 &acc_s,
Tensor1 &acc_o0, Tensor2 &acc_o1, Tensor3 &acc_o2, Tensor4 &acc_o3,
float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// === 将四个 acc_o 都转为 rowcol 布局 ===
Tensor acc_o0_rowcol = make_tensor(acc_o0.data(), flash::convert_layout_acc_rowcol(acc_o0.layout()));
Tensor acc_o1_rowcol = make_tensor(acc_o1.data(), flash::convert_layout_acc_rowcol(acc_o1.layout()));
Tensor acc_o2_rowcol = make_tensor(acc_o2.data(), flash::convert_layout_acc_rowcol(acc_o2.layout()));
Tensor acc_o3_rowcol = make_tensor(acc_o3.data(), flash::convert_layout_acc_rowcol(acc_o3.layout()));
static_assert(decltype(size<0>(acc_o0_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = custom_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o0_rowcol); ++ni) {
acc_o0_rowcol(mi, ni) *= scores_scale;
acc_o1_rowcol(mi, ni) *= scores_scale;
acc_o2_rowcol(mi, ni) *= scores_scale;
acc_o3_rowcol(mi, ni) *= scores_scale;
}
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
template<bool Is_first, bool Check_inf=false,
typename Tensor0,
typename Tensor1, typename Tensor2, typename Tensor3>
__forceinline__ __device__ void softmax_rescale_o(
Tensor0 &acc_s,
Tensor1 &acc_o0, Tensor2 &acc_o1, Tensor3 &acc_o2,
float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// === 将四个 acc_o 都转为 rowcol 布局 ===
Tensor acc_o0_rowcol = make_tensor(acc_o0.data(), flash::convert_layout_acc_rowcol(acc_o0.layout()));
Tensor acc_o1_rowcol = make_tensor(acc_o1.data(), flash::convert_layout_acc_rowcol(acc_o1.layout()));
Tensor acc_o2_rowcol = make_tensor(acc_o2.data(), flash::convert_layout_acc_rowcol(acc_o2.layout()));
static_assert(decltype(size<0>(acc_o0_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = custom_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o0_rowcol); ++ni) {
acc_o0_rowcol(mi, ni) *= scores_scale;
acc_o1_rowcol(mi, ni) *= scores_scale;
acc_o2_rowcol(mi, ni) *= scores_scale;
}
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
// Softmax rescale with max_diff return for dynamic PV skip optimization
// Returns max_diff = max(current_block_local_max - previous_global_max) following SpargeAttn convention
// Execute P@V when: max_diff + pv_threshold > 0 (current block contribution significant)
// Skip P@V when: max_diff + pv_threshold <= 0 (current block contribution negligible)
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ float softmax_rescale_o_with_diff(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
float local_max_diff = -INFINITY;
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// Note: row_sum will be initialized in accumulate_softmax_sum() for first block
// First block must always compute P@V, return +INFINITY to force execution
local_max_diff = INFINITY;
} else {
// ========== OPTIMIZED: Align with SpargeAttn, minimize overhead ==========
// Step 1: Save previous global max
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
// Step 2: Compute current block's LOCAL max into row_max temporarily
// This overwrites row_max with local max (will restore cumulative later)
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
// Step 3: Compute max_diff and update to cumulative max in single pass
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
// row_max now contains LOCAL max from current block
float scores_max_cur_local = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_max_prev_val = scores_max_prev(mi);
// Compute max_diff = local_max - global_max (can be negative!)
// This matches SpargeAttn convention (attn_utils.cuh:445)
float row_diff = (scores_max_cur_local - scores_max_prev_val) * softmax_scale_log2;
local_max_diff = max(local_max_diff, row_diff);
// Update row_max to cumulative max for rescaling
float scores_max_new_global = max(scores_max_prev_val, scores_max_cur_local);
row_max(mi) = scores_max_new_global;
// Rescale previous accumulations if global max increased
float scores_scale = custom_exp2f((scores_max_prev_val - scores_max_new_global) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale;
}
}
// Compute exp(scores - max) for P@V, but don't accumulate to row_sum yet
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// NOTE: row_sum accumulation is deferred to accumulate_softmax_sum()
}
return local_max_diff;
};
// Accumulate softmax probabilities to row_sum (denominator)
template<bool Is_first, typename Tensor0>
__forceinline__ __device__ void accumulate_softmax_sum(Tensor0 &acc_s) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
// Accumulate exp(scores) to row_sum
// acc_s already contains exp(scores - max) from softmax_rescale_o_with_diff
flash::reduce_sum</*zero_init=*/Is_first>(scores, row_sum);
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
#if 1
quad_allreduce_(row_sum, row_sum, sum_op);
#else
quad_allreduce_sum_(row_sum, row_sum, sum_op);
#endif
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
return lse;
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, Tensor1 &acc_o_tail, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
#if 1
quad_allreduce_(row_sum, row_sum, sum_op);
#else
quad_allreduce_sum_(row_sum, row_sum, sum_op);
#endif
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
Tensor acc_o_tail_rowcol = make_tensor(acc_o_tail.data(), flash::convert_layout_acc_rowcol(acc_o_tail.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
for (int ni = 0; ni < size<1>(acc_o_tail_rowcol); ++ni) { acc_o_tail_rowcol(mi, ni) *= scale; }
}
return lse;
};
template<bool Is_dropout=false, bool Split=false,
typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ TensorT normalize_softmax_lse(
Tensor0 &acc_o0, Tensor1 &acc_o1, Tensor2 &acc_o2,
float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
#if 1
quad_allreduce_(row_sum, row_sum, sum_op);
#else
quad_allreduce_sum_(row_sum, row_sum, sum_op);
#endif
TensorT lse = make_fragment_like(row_sum);
// === 将四个 acc_o 转换为 rowcol 布局 ===
Tensor acc_o0_rowcol = make_tensor(acc_o0.data(), flash::convert_layout_acc_rowcol(acc_o0.layout()));
Tensor acc_o1_rowcol = make_tensor(acc_o1.data(), flash::convert_layout_acc_rowcol(acc_o1.layout()));
Tensor acc_o2_rowcol = make_tensor(acc_o2.data(), flash::convert_layout_acc_rowcol(acc_o2.layout()));
static_assert(decltype(size<0>(acc_o0_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o0_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum)
? (Split ? -INFINITY : INFINITY)
: row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o0_rowcol); ++ni) {
acc_o0_rowcol(mi, ni) *= scale;
acc_o1_rowcol(mi, ni) *= scale;
acc_o2_rowcol(mi, ni) *= scale;
}
}
return lse;
}
template<bool Is_dropout=false, bool Split=false,
typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3>
__forceinline__ __device__ TensorT normalize_softmax_lse(
Tensor0 &acc_o0, Tensor1 &acc_o1, Tensor2 &acc_o2, Tensor3 &acc_o3,
float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
#if 1
quad_allreduce_(row_sum, row_sum, sum_op);
#else
quad_allreduce_sum_(row_sum, row_sum, sum_op);
#endif
TensorT lse = make_fragment_like(row_sum);
// === 将四个 acc_o 转换为 rowcol 布局 ===
Tensor acc_o0_rowcol = make_tensor(acc_o0.data(), flash::convert_layout_acc_rowcol(acc_o0.layout()));
Tensor acc_o1_rowcol = make_tensor(acc_o1.data(), flash::convert_layout_acc_rowcol(acc_o1.layout()));
Tensor acc_o2_rowcol = make_tensor(acc_o2.data(), flash::convert_layout_acc_rowcol(acc_o2.layout()));
Tensor acc_o3_rowcol = make_tensor(acc_o3.data(), flash::convert_layout_acc_rowcol(acc_o3.layout()));
static_assert(decltype(size<0>(acc_o0_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o0_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum)
? (Split ? -INFINITY : INFINITY)
: row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o0_rowcol); ++ni) {
acc_o0_rowcol(mi, ni) *= scale;
acc_o1_rowcol(mi, ni) *= scale;
acc_o2_rowcol(mi, ni) *= scale;
acc_o3_rowcol(mi, ni) *= scale;
}
}
return lse;
}
// ★ Attention Sinks: normalize with precomputed sink LogSumExp ★
template<bool Is_dropout=false, bool Split=false, typename Tensor0, typename TensorSAux>
__forceinline__ __device__ TensorT normalize_softmax_lse_with_sinks(
Tensor0 &acc_o,
TensorSAux const& tSrS_aux,
float softmax_scale,
float softmax_scale_log2,
float rp_dropout=1.0
) {
SumOp<float> sum_op;
#if 1
quad_allreduce_(row_sum, row_sum, sum_op);
#else
quad_allreduce_sum_(row_sum, row_sum, sum_op);
#endif
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
// Handle -INFINITY case for empty sequences
if (row_max(mi) == -INFINITY) { row_max(mi) = 0.f; }
const float max_scaled = row_max(mi) * softmax_scale_log2;
// Compute sink tokens' contribution to softmax denominator
// exp(s_aux - max/√d) = exp2(log2(e) * s_aux - max * log2(e) / √d)
#ifndef M_LOG2E
#define M_LOG2E 1.44269504088896340736
#endif
float sink_contrib = custom_exp2f(float(M_LOG2E) * tSrS_aux(mi) - max_scaled);
float sum = row_sum(mi) + sink_contrib;
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY)
: row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scale;
}
}
return lse;
};
// ★ Attention Sinks: normalize with precomputed sink LogSumExp (with tail for VLLM) ★
template<bool Is_dropout=false, bool Split=false, typename Tensor0, typename Tensor1, typename TensorSAux>
__forceinline__ __device__ TensorT normalize_softmax_lse_with_sinks_tail(
Tensor0 &acc_o,
Tensor1 &acc_o_tail,
TensorSAux const& tSrS_aux,
float softmax_scale,
float softmax_scale_log2,
float rp_dropout=1.0
) {
SumOp<float> sum_op;
#if 1
quad_allreduce_(row_sum, row_sum, sum_op);
#else
quad_allreduce_sum_(row_sum, row_sum, sum_op);
#endif
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
Tensor acc_o_tail_rowcol = make_tensor(acc_o_tail.data(), flash::convert_layout_acc_rowcol(acc_o_tail.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
// Handle -INFINITY case for empty sequences
if (row_max(mi) == -INFINITY) { row_max(mi) = 0.f; }
const float max_scaled = row_max(mi) * softmax_scale_log2;
// Compute sink tokens' contribution to softmax denominator
// exp(s_aux - max/√d) = exp2(log2(e) * s_aux - max * log2(e) / √d)
#ifndef M_LOG2E
#define M_LOG2E 1.44269504088896340736
#endif
float sink_contrib = custom_exp2f(float(M_LOG2E) * tSrS_aux(mi) - max_scaled);
float sum = row_sum(mi) + sink_contrib;
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY)
: row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scale;
}
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_tail_rowcol); ++ni) {
acc_o_tail_rowcol(mi, ni) *= scale;
}
}
return lse;
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8(Tensor0 &acc_o, float softmax_scale, float v_descale,float rp_dropout=1.0) {
SumOp<float> sum_op;
#if 1
quad_allreduce_(row_sum, row_sum, sum_op);
#else
quad_allreduce_sum_(row_sum, row_sum, sum_op);
#endif
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : v_descale / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
return lse;
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8(Tensor0 &acc_o, Tensor1 &acc_o_tail, float softmax_scale, float v_descale,float rp_dropout=1.0) {
SumOp<float> sum_op;
#if 1
quad_allreduce_(row_sum, row_sum, sum_op);
#else
quad_allreduce_sum_(row_sum, row_sum, sum_op);
#endif
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
Tensor acc_o_tail_rowcol = make_tensor(acc_o_tail.data(), flash::convert_layout_acc_rowcol(acc_o_tail.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : v_descale / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
for (int ni = 0; ni < size<1>(acc_o_tail_rowcol); ++ni) { acc_o_tail_rowcol(mi, ni) *= scale; }
}
return lse;
};
template<bool Is_dropout=false, bool Split=false,
typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8(
Tensor0 &acc_o0, Tensor1 &acc_o1, Tensor2 &acc_o2,
float softmax_scale, float v_scale=1.0, float rp_dropout=1.0) {
SumOp<float> sum_op;
#if 1
quad_allreduce_(row_sum, row_sum, sum_op);
#else
quad_allreduce_sum_(row_sum, row_sum, sum_op);
#endif
TensorT lse = make_fragment_like(row_sum);
// === 将四个 acc_o 转换为 rowcol 布局 ===
Tensor acc_o0_rowcol = make_tensor(acc_o0.data(), flash::convert_layout_acc_rowcol(acc_o0.layout()));
Tensor acc_o1_rowcol = make_tensor(acc_o1.data(), flash::convert_layout_acc_rowcol(acc_o1.layout()));
Tensor acc_o2_rowcol = make_tensor(acc_o2.data(), flash::convert_layout_acc_rowcol(acc_o2.layout()));
static_assert(decltype(size<0>(acc_o0_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o0_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : v_scale / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o0_rowcol); ++ni) {
acc_o0_rowcol(mi, ni) *= scale;
acc_o1_rowcol(mi, ni) *= scale;
acc_o2_rowcol(mi, ni) *= scale;
}
}
return lse;
}
template<bool Is_dropout=false, bool Split=false,
typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8(
Tensor0 &acc_o0, Tensor1 &acc_o1, Tensor2 &acc_o2, Tensor3 &acc_o3,
float softmax_scale, float v_scale=1.0, float rp_dropout=1.0) {
SumOp<float> sum_op;
#if 1
quad_allreduce_(row_sum, row_sum, sum_op);
#else
quad_allreduce_sum_(row_sum, row_sum, sum_op);
#endif
TensorT lse = make_fragment_like(row_sum);
// === 将四个 acc_o 转换为 rowcol 布局 ===
Tensor acc_o0_rowcol = make_tensor(acc_o0.data(), flash::convert_layout_acc_rowcol(acc_o0.layout()));
Tensor acc_o1_rowcol = make_tensor(acc_o1.data(), flash::convert_layout_acc_rowcol(acc_o1.layout()));
Tensor acc_o2_rowcol = make_tensor(acc_o2.data(), flash::convert_layout_acc_rowcol(acc_o2.layout()));
Tensor acc_o3_rowcol = make_tensor(acc_o3.data(), flash::convert_layout_acc_rowcol(acc_o3.layout()));
static_assert(decltype(size<0>(acc_o0_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o0_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : v_scale / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o0_rowcol); ++ni) {
acc_o0_rowcol(mi, ni) *= scale;
acc_o1_rowcol(mi, ni) *= scale;
acc_o2_rowcol(mi, ni) *= scale;
acc_o3_rowcol(mi, ni) *= scale;
}
}
return lse;
}
};
} // namespace flash
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#ifdef FLASHATTENTION_DISABLE_DROPOUT
#define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define DROPOUT_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_ALIBI
#define ALIBI_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define ALIBI_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
#define EVENK_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
}()
#else
#define EVENK_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
#define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define SOFTCAP_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_LOCAL
#define LOCAL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define LOCAL_SWITCH BOOL_SWITCH
#endif
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type = cutlass::half_t; \
return __VA_ARGS__(); \
} else { \
using elem_type = cutlass::bfloat16_t; \
return __VA_ARGS__(); \
} \
}()
#define FP8_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type_fp8 = cutlass::float_e4m3_t;\
return __VA_ARGS__(); \
} else { \
using elem_type_fp8 = cutlass::float_e5m2_t; \
return __VA_ARGS__(); \
} \
}()
#define HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 64) { \
constexpr static int kHeadDim = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 96) { \
constexpr static int kHeadDim = 96; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 128) { \
constexpr static int kHeadDim = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 160) { \
constexpr static int kHeadDim = 160; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 192) { \
constexpr static int kHeadDim = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 224) { \
constexpr static int kHeadDim = 224; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 256) { \
constexpr static int kHeadDim = 256; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 512) { \
constexpr static int kHeadDim = 512; \
return __VA_ARGS__(); \
} \
}()
#define HEADDIM_SWITCH_FP8(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 64) { \
constexpr static int kHeadDim = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 128) { \
constexpr static int kHeadDim = 128; \
return __VA_ARGS__(); \
}else if (HEADDIM <= 192) { \
constexpr static int kHeadDim = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 256) { \
constexpr static int kHeadDim = 256; \
return __VA_ARGS__(); \
} \
}()
// #define HEADDIM_SWITCH_SLA HEADDIM_SWITCH_FP8
#define HEADDIM_SWITCH_SLA(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 64) { \
constexpr static int kHeadDim = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 128) { \
constexpr static int kHeadDim = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 256) { \
constexpr static int kHeadDim = 256; \
return __VA_ARGS__(); \
} \
}()
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
// #include <cutlass/arch/memory_buffer.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<const int COUNT>
__forceinline__ __device__ void s_nop() {
asm volatile("s_nop %0":: "B"(COUNT) :);
}
__forceinline__ __device__ void s_barrier() {
asm volatile("s_barrier");
}
template<const int COUNT>
__forceinline__ __device__ void s_waitcnt() {
asm volatile(
"s_waitcnt vmcnt(%0)\n\t"
"s_barrier\n"
:: "B"(COUNT)
:);
}
template<const int COUNT>
__forceinline__ __device__ void s_waitcnt_nosync() {
asm volatile(
"s_waitcnt vmcnt(%0)\n\t"
:: "B"(COUNT)
:);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
__forceinline__ __device__ uint32_t relu2(const uint32_t x);
template<>
__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
#ifdef __HIP_DEVICE_COMPILE__
// 暂时不使用ptx指令,后续优化点
const auto x_p = reinterpret_cast<const cutlass::half_t*>(&x);
auto res_p = reinterpret_cast<cutlass::half_t*>(&res);
res_p[0] = (x_p[0] >= cutlass::half_t(0)) ? x_p[0] : cutlass::half_t(0);
res_p[1] = (x_p[1] >= cutlass::half_t(0)) ? x_p[1] : cutlass::half_t(0);
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
// asm volatile( \
// "{\n" \
// "\t .reg .f16x2 sela;\n" \
// "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
// "\t and.b32 %0, sela, %1;\n"
// "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
#endif
return res;
}
template<>
__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
#ifdef __HIP_DEVICE_COMPILE__
// 暂时不使用ptx指令,后续优化点
const auto x_p = reinterpret_cast<const cutlass::bfloat16_t*>(&x);
auto res_p = reinterpret_cast<cutlass::bfloat16_t*>(&res);
res_p[0] = (x_p[0] >= cutlass::bfloat16_t(0)) ? x_p[0] : cutlass::bfloat16_t(0);
res_p[1] = (x_p[1] >= cutlass::bfloat16_t(0)) ? x_p[1] : cutlass::bfloat16_t(0);
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
#endif
return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<typename T>
__forceinline__ __device__ uint32_t convert_relu2(const float2 x);
template<>
__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
// asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
return res;
}
template<>
__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
// asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
return res;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 64 || THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4 || THREADS == 2);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor(x, OFFSET, 64));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<1> {
// static_assert(THREADS == 64 || THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4 || THREADS == 2);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<32> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor(x, 16, 64));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopyA, typename TiledCopyB,
typename ThrCopyA, typename ThrCopyB>
__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma,
TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
if(!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
if(!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
#pragma unroll
for(int i = 0; i < size<2>(tCrA); ++i) {
if(i < size<2>(tCrA) - 1) {
if(!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
if(!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
if (i == 0) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
if (i == 0) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
}
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_B(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 16);
// static_assert(decltype(size<2>(acc_layout))::value == 1);
static_assert(decltype(rank(acc_layout))::value == 3);
// return make_layout(get<0>(get<0>(acc_layout)), get<1>(acc_layout), get<1>(get<0>(acc_layout)));
return make_layout(get<0>(get<0>(acc_layout)), make_layout(get<1>(get<0>(acc_layout)), get<1>(acc_layout)), get<2>(acc_layout));
};
template<typename Element, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs_pad(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, int max_mn) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
auto tCrB_ = make_tensor(tCrB.data(), convert_layout_acc_B(tCrB.layout()));
int col = i * 16 + ((threadIdx.x % 64) / 16) * 4;
for (int j = 0; j < size<0>(tCrB_); j++) {
for (int k = 0; k < size<1>(tCrB_); k++) {
tCrB_(j, k, i) = col + j >= max_mn ? Element(0.0f) : tCrB_(j, k, i);
}
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
if (i == 0) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
}
// template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
// typename TiledMma, typename TiledCopy, typename ThrCopy>
// __forceinline__ __device__ void gemm_rs_debug__(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
// TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
// ThrCopy smem_thr_copy_B) {
// CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
// CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
// Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
// CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
// cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
// if (block0())
// {
// printf("tidx = %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n", threadIdx.x,
// float(tCrB(0, 0, 0)),
// float(tCrB(1, 0, 0)),
// float(tCrB(2, 0, 0)),
// float(tCrB(3, 0, 0)),
// float(tCrB(4, 0, 0)),
// float(tCrB(5, 0, 0)),
// float(tCrB(6, 0, 0)),
// float(tCrB(7, 0, 0))
// );
// }
// // #pragma unroll
// // for (int i = 0; i < size<2>(tCrA); ++i) {
// // if (i < size<2>(tCrA) - 1) {
// // cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
// // }
// // cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
// // if (i == 0) {
// // __builtin_amdgcn_sched_barrier(0);
// // __builtin_amdgcn_s_setprio(1);
// // __builtin_amdgcn_sched_barrier(0);
// // }
// // }
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_setprio(0);
// __builtin_amdgcn_sched_barrier(0);
// }
template<typename Element,typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs_pad_ws(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, int k_idx, int Max_Mn) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_setprio(0);
// __builtin_amdgcn_sched_barrier(0);
// using From_type = typename Tensor0::Engine::value_type;
int tidx = threadIdx.x;
// __builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_B, tCsB(_, _, k_idx), tCrB_copy_view(_, _, k_idx));
// __builtin_amdgcn_sched_barrier(0);
int need_pad_k_idx = Max_Mn / 16;
if (need_pad_k_idx == k_idx) {
auto tCrB_ = make_tensor(tCrB.data(), convert_layout_acc_B(tCrB.layout()));
for (int ni = 0; ni < size<1>(tCrB_); ni++) {
int col = k_idx * 16 + ((tidx % 64) / 16) * 4;
for (int ei = 0; ei < size<0>(tCrB_); ei++) {
tCrB_(ei, ni, k_idx) = col + ei >= Max_Mn ? Element(0) : tCrB_(ei, ni, k_idx);
}
}
}
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
template<typename Element,typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs_pad(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, int k_idx, int Max_Mn) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_setprio(0);
// __builtin_amdgcn_sched_barrier(0);
// using From_type = typename Tensor0::Engine::value_type;
int tidx = threadIdx.x;
cute::copy(smem_tiled_copy_B, tCsB(_, _, k_idx), tCrB_copy_view(_, _, k_idx));
int need_pad_k_idx = Max_Mn / 16;
int round_4 = Max_Mn % 4;
if (need_pad_k_idx == k_idx && round_4 != 0) {
auto tCrB_ = make_tensor(tCrB.data(), convert_layout_acc_B(tCrB.layout()));
for (int ni = 0; ni < size<1>(tCrB_); ni++)
{
int col = k_idx * 16 + ((tidx % 64) / 16) * 4;
for (int ei = 0; ei < size<0>(tCrB_); ei++)
{
tCrB_(ei, ni, k_idx) = col + ei >= Max_Mn ? Element(0) : tCrB_(ei, ni, k_idx);
}
}
}
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_setprio(1);
// __builtin_amdgcn_sched_barrier(0);
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, int k_idx) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_setprio(0);
// __builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_B, tCsB(_, _, k_idx), tCrB_copy_view(_, _, k_idx));
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_setprio(1);
// __builtin_amdgcn_sched_barrier(0);
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, int kA_idx, int kB_idx) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
// CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, kB_idx), tCrB_copy_view(_, _, kB_idx));
cute::gemm(tiled_mma, tCrA(_, _, kA_idx), tCrB(_, _, kB_idx), acc);
}
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs_debug(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, int k_idx) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, k_idx), tCrB_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
int tidx = threadIdx.x;
printf("tid:%d k_idx:%d tCrA:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"tCrB:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"acc:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f \n", tidx, k_idx,
float(tCrA(0, 0, k_idx)), float(tCrA(1, 0, k_idx)), float(tCrA(2, 0, k_idx)), float(tCrA(3, 0, k_idx)),
float(tCrA(4, 0, k_idx)), float(tCrA(5, 0, k_idx)), float(tCrA(6, 0, k_idx)), float(tCrA(7, 0, k_idx)),
float(tCrB(0, 0, k_idx)), float(tCrB(1, 0, k_idx)), float(tCrB(2, 0, k_idx)), float(tCrB(3, 0, k_idx)),
float(tCrB(4, 0, k_idx)), float(tCrB(5, 0, k_idx)), float(tCrB(6, 0, k_idx)), float(tCrB(7, 0, k_idx)),
acc(0), acc(1), acc(2), acc(3), acc(4), acc(5), acc(6), acc(7),
acc(8), acc(9), acc(10), acc(11), acc(12), acc(13), acc(14), acc(15),
acc(16), acc(17), acc(18), acc(19), acc(20), acc(21), acc(22), acc(23),
acc(24), acc(25), acc(26), acc(27), acc(28), acc(29), acc(30), acc(31)
);
}
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs_debug(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, int kA_idx, int kB_idx) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
// CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, kB_idx), tCrB_copy_view(_, _, kB_idx));
cute::gemm(tiled_mma, tCrA(_, _, kA_idx), tCrB(_, _, kB_idx), acc);
int tidx = threadIdx.x;
printf("tid:%d kA_idx:%d kB_idx:%d tCrA:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"tCrB:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"acc:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f \n", tidx, kA_idx, kB_idx,
float(tCrA(0, 0, kA_idx)), float(tCrA(1, 0, kA_idx)), float(tCrA(2, 0, kA_idx)), float(tCrA(3, 0, kA_idx)),
float(tCrA(4, 0, kA_idx)), float(tCrA(5, 0, kA_idx)), float(tCrA(6, 0, kA_idx)), float(tCrA(7, 0, kA_idx)),
float(tCrB(0, 0, kB_idx)), float(tCrB(1, 0, kB_idx)), float(tCrB(2, 0, kB_idx)), float(tCrB(3, 0, kB_idx)),
float(tCrB(4, 0, kB_idx)), float(tCrB(5, 0, kB_idx)), float(tCrB(6, 0, kB_idx)), float(tCrB(7, 0, kB_idx)),
float(tCrB(8, 0, kB_idx)), float(tCrB(9, 0, kB_idx)), float(tCrB(10, 0, kB_idx)), float(tCrB(11, 0, kB_idx)),
float(tCrB(12, 0, kB_idx)), float(tCrB(13, 0, kB_idx)), float(tCrB(14, 0, kB_idx)), float(tCrB(15, 0, kB_idx)),
float(tCrB(16, 0, kB_idx)), float(tCrB(17, 0, kB_idx)), float(tCrB(18, 0, kB_idx)), float(tCrB(19, 0, kB_idx)),
float(tCrB(20, 0, kB_idx)), float(tCrB(21, 0, kB_idx)), float(tCrB(22, 0, kB_idx)), float(tCrB(23, 0, kB_idx)),
float(tCrB(24, 0, kB_idx)), float(tCrB(25, 0, kB_idx)), float(tCrB(26, 0, kB_idx)), float(tCrB(27, 0, kB_idx)),
float(tCrB(28, 0, kB_idx)), float(tCrB(29, 0, kB_idx)), float(tCrB(30, 0, kB_idx)), float(tCrB(31, 0, kB_idx)),
acc(0), acc(1), acc(2), acc(3), acc(4), acc(5), acc(6), acc(7),
acc(8), acc(9), acc(10), acc(11), acc(12), acc(13), acc(14), acc(15),
acc(16), acc(17), acc(18), acc(19), acc(20), acc(21), acc(22), acc(23),
acc(24), acc(25), acc(26), acc(27), acc(28), acc(29), acc(30), acc(31)
);
}
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs_swait(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
asm volatile("s_waitcnt lgkmcnt(0)\n\t" : :);
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
asm volatile("s_waitcnt lgkmcnt(0)\n\t" : :);
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
if (i == 0) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
}
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs_swait(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, int k_idx) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, k_idx), tCrB_copy_view(_, _, k_idx));
asm volatile("s_waitcnt lgkmcnt(0)");
// int tidx = threadIdx.x;
// printf("tid:%d k_idx:%d tCrA:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "tCrB:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f\n", tidx, k_idx,
// float(tCrA(0, 0, k_idx)), float(tCrA(1, 0, k_idx)), float(tCrA(2, 0, k_idx)), float(tCrA(3, 0, k_idx)),
// float(tCrA(4, 0, k_idx)), float(tCrA(5, 0, k_idx)), float(tCrA(6, 0, k_idx)), float(tCrA(7, 0, k_idx)),
// float(tCrA(0, 1, k_idx)), float(tCrA(1, 1, k_idx)), float(tCrA(2, 1, k_idx)), float(tCrA(3, 1, k_idx)),
// float(tCrA(4, 1, k_idx)), float(tCrA(5, 1, k_idx)), float(tCrA(6, 1, k_idx)), float(tCrA(7, 1, k_idx)),
// float(tCsB(0, 0, k_idx)), float(tCsB(1, 0, k_idx)), float(tCsB(2, 0, k_idx)), float(tCsB(3, 0, k_idx)),
// float(tCsB(4, 0, k_idx)), float(tCsB(5, 0, k_idx)), float(tCsB(6, 0, k_idx)), float(tCsB(7, 0, k_idx))
// );
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs_debug(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
// wangaq debug
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// int offset = reinterpret_cast<const char *>(&tCsB(0, 0, _0{})) - (char *)(0x1000000000000);
// printf("tid:%d i:0 tCsB:%p %p %p %p "
// "%p %p %p %p "
// "offset:%d row:%d col:%d\n", threadIdx.x,
// &tCsB(0, 0, _0{}), &tCsB(1, 0, _0{}), &tCsB(2, 0, _0{}), &tCsB(3, 0, _0{}),
// &tCsB(4, 0, _0{}), &tCsB(5, 0, _0{}), &tCsB(6, 0, _0{}), &tCsB(7, 0, _0{}),
// offset, offset/128, (offset % 128)/16);
// }
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
// wangaq debug
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// int offset = reinterpret_cast<const char *>(&tCsB(0, 0, _0{})) - (char *)(0x1000000000000);
// printf("tid:%d i:%d tCsB:%p %p %p %p "
// "%p %p %p %p "
// "offset:%d row:%d col:%d\n", threadIdx.x, i + 1,
// &tCsB(0, 0, i + 1), &tCsB(1, 0, i + 1), &tCsB(2, 0, i + 1), &tCsB(3, 0, i + 1),
// &tCsB(4, 0, i + 1), &tCsB(5, 0, i + 1), &tCsB(6, 0, i + 1), &tCsB(7, 0, i + 1),
// offset, offset/128, (offset % 128)/16);
// }
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
// wangaq debug
// if(thread0()) {
// printf("i:%d tCrA:%7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | "
// "%7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f \n"
// "tCrB:%7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | "
// "%7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f \n"
// "acc:%7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | "
// "%7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f | %7.4f %7.4f %7.4f %7.4f \n", i,
// float(tCrA(0, 0, i)), float(tCrA(1, 0, i)), float(tCrA(2, 0, i)), float(tCrA(3, 0, i)),
// float(tCrA(4, 0, i)), float(tCrA(5, 0, i)), float(tCrA(6, 0, i)), float(tCrA(7, 0, i)),
// float(tCrA(8, 0, i)), float(tCrA(9, 0, i)), float(tCrA(10, 0, i)), float(tCrA(11, 0, i)),
// float(tCrA(12, 0, i)), float(tCrA(13, 0, i)), float(tCrA(14, 0, i)), float(tCrA(15, 0, i)),
// float(tCrA(0, 1, i)), float(tCrA(1, 1, i)), float(tCrA(2, 1, i)), float(tCrA(3, 1, i)),
// float(tCrA(4, 1, i)), float(tCrA(5, 1, i)), float(tCrA(6, 1, i)), float(tCrA(7, 1, i)),
// float(tCrA(8, 1, i)), float(tCrA(9, 1, i)), float(tCrA(10, 1, i)), float(tCrA(11, 1, i)),
// float(tCrA(12, 1, i)), float(tCrA(13, 1, i)), float(tCrA(14, 1, i)), float(tCrA(15, 1, i)),
// float(tCrB(0, 0, i)), float(tCrB(1, 0, i)), float(tCrB(2, 0, i)), float(tCrB(3, 0, i)),
// float(tCrB(4, 0, i)), float(tCrB(5, 0, i)), float(tCrB(6, 0, i)), float(tCrB(7, 0, i)),
// float(tCrB(0, 1, i)), float(tCrB(1, 1, i)), float(tCrB(2, 1, i)), float(tCrB(3, 1, i)),
// float(tCrB(4, 1, i)), float(tCrB(5, 1, i)), float(tCrB(6, 1, i)), float(tCrB(7, 1, i)),
// float(tCrB(0, 2, i)), float(tCrB(1, 2, i)), float(tCrB(2, 2, i)), float(tCrB(3, 2, i)),
// float(tCrB(4, 2, i)), float(tCrB(5, 2, i)), float(tCrB(6, 2, i)), float(tCrB(7, 2, i)),
// float(tCrB(0, 3, i)), float(tCrB(1, 3, i)), float(tCrB(2, 3, i)), float(tCrB(3, 3, i)),
// float(tCrB(4, 3, i)), float(tCrB(5, 3, i)), float(tCrB(6, 3, i)), float(tCrB(7, 3, i)),
// acc(0), acc(1), acc(2), acc(3), acc(4), acc(5), acc(6), acc(7),
// acc(8), acc(9), acc(10), acc(11), acc(12), acc(13), acc(14), acc(15),
// acc(16), acc(17), acc(18), acc(19), acc(20), acc(21), acc(22), acc(23),
// acc(24), acc(25), acc(26), acc(27), acc(28), acc(29), acc(30), acc(31)
// );
// }
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm_rr(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, TiledMma tiled_mma) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
if (i == 0) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int row, int col, typename Tensor0, typename Tensor1>
__forceinline__ __device__ static void __ds_read_m32x16_row_col(Tensor0& src, Tensor1& dst)
{
auto lds = reinterpret_cast<__fp16 *>(src.data().get());
auto layout = src.layout();
constexpr short offset = layout(0, row, col) * 2;
auto d = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset);
uint16_t * d_ptr = reinterpret_cast<uint16_t*>(&d);
uint16_t * dst_ptr = reinterpret_cast<uint16_t*>(&(dst(0, row, col)));
dst_ptr[0] = d_ptr[0];
dst_ptr[1] = d_ptr[1];
dst_ptr[2] = d_ptr[2];
dst_ptr[3] = d_ptr[3];
dst_ptr[4] = d_ptr[4];
dst_ptr[5] = d_ptr[5];
dst_ptr[6] = d_ptr[6];
dst_ptr[7] = d_ptr[7];
}
template<int k_idx, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs_ds_read_m32x16(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
auto shape = tCsB.shape();
constexpr int rows = get<1>(shape);
static_assert(rows == 6 || rows == 4 || rows == 3 || rows == 2);
if constexpr (rows == 6) {
__ds_read_m32x16_row_col<0, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<2, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<3, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<4, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<5, k_idx>(tCsB, tCrB_copy_view);
} else if constexpr (rows == 4) {
__ds_read_m32x16_row_col<0, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<2, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<3, k_idx>(tCsB, tCrB_copy_view);
} else if constexpr (rows == 3) {
__ds_read_m32x16_row_col<0, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<2, k_idx>(tCsB, tCrB_copy_view);
}
else if constexpr (rows == 2) {
__ds_read_m32x16_row_col<0, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idx>(tCsB, tCrB_copy_view);
}
// cute::copy(smem_tiled_copy_B, tCsB(_, _, k_idx), tCrB_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
template<int k_idxA, int k_idxB, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs_ds_read_m32x16(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
auto shape = tCsB.shape();
constexpr int rows = get<1>(shape);
static_assert(rows == 6 || rows == 4 || rows == 3 || rows == 2);
if constexpr (rows == 6) {
__ds_read_m32x16_row_col<0, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<2, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<3, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<4, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<5, k_idxB>(tCsB, tCrB_copy_view);
} else if constexpr (rows == 4) {
__ds_read_m32x16_row_col<0, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<2, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<3, k_idxB>(tCsB, tCrB_copy_view);
} else if constexpr (rows == 3) {
__ds_read_m32x16_row_col<0, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<2, k_idxB>(tCsB, tCrB_copy_view);
}
else if constexpr (rows == 2) {
__ds_read_m32x16_row_col<0, k_idxB>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idxB>(tCsB, tCrB_copy_view);
}
cute::gemm(tiled_mma, tCrA(_, _, k_idxA), tCrB(_, _, k_idxB), acc);
}
template<int k_idx, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs_ds_read_m32x16_debug(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
auto shape = tCsB.shape();
constexpr int rows = get<1>(shape);
static_assert(rows == 6 || rows == 4 || rows == 3 || rows == 2);
if constexpr (rows == 6) {
__ds_read_m32x16_row_col<0, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<2, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<3, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<4, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<5, k_idx>(tCsB, tCrB_copy_view);
} else if constexpr (rows == 4) {
__ds_read_m32x16_row_col<0, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<2, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<3, k_idx>(tCsB, tCrB_copy_view);
} else if constexpr (rows == 3) {
__ds_read_m32x16_row_col<0, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<2, k_idx>(tCsB, tCrB_copy_view);
}
else if constexpr (rows == 2) {
__ds_read_m32x16_row_col<0, k_idx>(tCsB, tCrB_copy_view);
__ds_read_m32x16_row_col<1, k_idx>(tCsB, tCrB_copy_view);
}
// cute::copy(smem_tiled_copy_B, tCsB(_, _, k_idx), tCrB_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
int tidx = threadIdx.x;
printf("tid:%d k_idx:%d tCrA:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"tCrB:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"acc:%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
"%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f \n", tidx, k_idx,
float(tCrA(0, 0, k_idx)), float(tCrA(1, 0, k_idx)), float(tCrA(2, 0, k_idx)), float(tCrA(3, 0, k_idx)),
float(tCrA(4, 0, k_idx)), float(tCrA(5, 0, k_idx)), float(tCrA(6, 0, k_idx)), float(tCrA(7, 0, k_idx)),
float(tCrB(0, 0, k_idx)), float(tCrB(1, 0, k_idx)), float(tCrB(2, 0, k_idx)), float(tCrB(3, 0, k_idx)),
float(tCrB(4, 0, k_idx)), float(tCrB(5, 0, k_idx)), float(tCrB(6, 0, k_idx)), float(tCrB(7, 0, k_idx)),
acc(0), acc(1), acc(2), acc(3), acc(4), acc(5), acc(6), acc(7),
acc(8), acc(9), acc(10), acc(11), acc(12), acc(13), acc(14), acc(15),
acc(16), acc(17), acc(18), acc(19), acc(20), acc(21), acc(22), acc(23),
acc(24), acc(25), acc(26), acc(27), acc(28), acc(29), acc(30), acc(31)
);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
// static_assert(decltype(size<0>(acc_layout))::value == 4 || decltype(size<0>(acc_layout))::value == 8);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_1>{}); // (_4,_1,_2):(_1,_0,_4) -> ((_1,_4),_1,_2):((_0,_1),_0,_4)
return make_layout(make_layout(get<1>(l)), make_layout(get<1>(get<0>(l)), get<2>(l))); // (1, (4, 2)):((_0),(_1,_4))
};
template<typename Layout>
__forceinline__ __device__ auto convert_trans_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
return make_layout(
make_layout(get<0>(acc_layout), get<2>(acc_layout)),
make_layout(get<1>(acc_layout)));
};
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 16);
// static_assert(decltype(size<2>(acc_layout))::value == 1);
static_assert(decltype(rank(acc_layout))::value == 3);
// return make_layout(get<0>(get<0>(acc_layout)), get<1>(acc_layout), get<1>(get<0>(acc_layout)));
return make_layout(get<0>(get<0>(acc_layout)), get<1>(acc_layout), make_layout(get<1>(get<0>(acc_layout)), get<2>(acc_layout)));
};
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_fp8(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 16);
// static_assert(decltype(size<2>(acc_layout))::value == 1);
static_assert(decltype(rank(acc_layout))::value == 3);
// return make_layout(get<0>(get<0>(acc_layout)), get<1>(acc_layout), get<1>(get<0>(acc_layout)));
return make_layout(get<0>(get<0>(acc_layout)), get<1>(acc_layout), make_layout(get<1>(get<0>(acc_layout)), get<2>(acc_layout)));
};
// template<typename Layout>
// __forceinline__ __device__ auto convert_layout_acc_back(Layout acc_layout) {
// using X = Underscore;
// static_assert(decltype(size<0>(acc_layout))::value == 4);
// static_assert(decltype(rank(acc_layout))::value == 3);
// auto l = logical_divide(acc_layout, Shape<X, X, _1>{});
// return make_layout(make_layout(get<0>(l), get<1>(get<2>(l))), get<1>(l), get<0>(get<2>(l)));
// };
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// template<typename MMA_traits, typename Layout>
// __forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
// using X = Underscore;
// static_assert(decltype(size<0>(acc_layout))::value == 4);
// static_assert(decltype(rank(acc_layout))::value == 3);
// constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
// static_assert(mma_shape_K == 8 || mma_shape_K == 16);
// // if constexpr (mma_shape_K == 8) {
// // return acc_layout;
// // } else {
// // auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
// // return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
// // }
// };
template <class TiledMma,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ auto convert_layout_acc_Aregs(const TiledMma& tiled_mma, Tensor<Engine0, Layout0> const& tOrP,
Tensor<Engine1, Layout1> const& sAcc)
{
int tidx = threadIdx.x;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
auto smem_tiled_copy_ACC = make_tiled_copy_C(Copy_Atom<DefaultCopy, cute::half_t>{}, tiled_mma);
auto smem_thr_copy_ACC = smem_tiled_copy_ACC.get_thread_slice(tidx);
Tensor taccOr = smem_thr_copy_ACC.retile_S(tOrP);
Tensor taccOs = smem_thr_copy_ACC.partition_D(sAcc);
// if (cute::thread0())
// { taccOr
// raw_ptr_16b(0x2000000000010) o ((_1,_4),_1,_4):((_0,_1),_0,_4)
// print("taccOr\n"); print(taccOr); print("\n");
// }
cute::copy(smem_tiled_copy_ACC, taccOr, taccOs);
// asm volatile("s_waitcnt lgkmcnt(0)\n\t");
__syncthreads();
auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<DefaultCopy, cute::half_t>{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
Tensor tSsACC = smem_thr_copy_A.partition_S(sAcc);
Tensor tSrACC = thr_mma.partition_fragment_A(sAcc);
Tensor tSrACC_copy_view = smem_thr_copy_A.retile_D(tSrACC);
cute::copy(smem_tiled_copy_ACC, tSsACC, tSrACC_copy_view);
// asm volatile("s_waitcnt lgkmcnt(0)\n\t");
// __syncthreads(); // 取消这个sync,2024.06.13
return tSrACC;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
// auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
auto l = logical_divide(acc_layout, Shape<X, X, _1>{}); // (4, MMA_M, (1, MMA_N)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); // ((4, 1), 1, 2):((1, 0), 0, 4)
};
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_back(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<X, X, _1>{});
return make_layout(make_layout(get<0>(l), get<1>(get<2>(l))), get<1>(l), get<0>(get<2>(l)));
};
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_back_fp8(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0>(acc_layout))::value == 8);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<X, X, _1>{});
return make_layout(make_layout(get<0>(l), get<1>(get<2>(l))), get<1>(l), get<0>(get<2>(l)));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
Tensor tensor_To_type = make_tensor<To_type>(layout(tensor));
cutlass::Array<To_type, numel> *result_ptr = reinterpret_cast<cutlass::Array<To_type, numel> *>(tensor_To_type.data());
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
#ifndef FLASH_ATTENTION_BF16_TYPE
#define FLASH_ATTENTION_BF16_TYPE 0
#endif
#if FLASH_ATTENTION_BF16_TYPE == 1
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_toward_zero> convert_op;
#elif FLASH_ATTENTION_BF16_TYPE == 2
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_to_nearest> convert_op;
#else
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_half_ulp_truncate> convert_op;
#endif
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
} else {
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
return tensor_To_type;
// cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// // HACK: this requires tensor to be "contiguous"
// auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
// return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type_fp8(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
if constexpr (std::is_same_v<To_type, From_type>)
{
return tensor;
}
constexpr int numel = decltype(size(tensor))::value;
Tensor tensor_To_type = make_tensor<To_type>(layout(tensor));
cutlass::Array<To_type, numel> *result_ptr = reinterpret_cast<cutlass::Array<To_type, numel> *>(tensor_To_type.data());
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_to_nearest> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
else if constexpr (std::is_same_v<To_type, cutlass::float_e4m3_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel,cutlass::FloatRoundStyle::round_to_nearest> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
else if constexpr (std::is_same_v<To_type, cutlass::float_e5m2_t>) {
cutlass::NumericArrayConverter<To_type, From_type, numel,cutlass::FloatRoundStyle::round_to_nearest> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
else {
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
}
return tensor_To_type;
}
template <typename To_type, typename From_type>
__forceinline__ __device__ auto convert_type(From_type const &from) {
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
#ifndef FLASH_ATTENTION_BF16_TYPE
#define FLASH_ATTENTION_BF16_TYPE 0
#endif
#if FLASH_ATTENTION_BF16_TYPE == 1
cutlass::NumericConverter<To_type, From_type, cutlass::FloatRoundStyle::round_toward_zero> convert_;
#elif FLASH_ATTENTION_BF16_TYPE == 2
cutlass::NumericConverter<To_type, From_type, cutlass::FloatRoundStyle::round_to_nearest> convert_;
#else
cutlass::NumericConverter<To_type, From_type, cutlass::FloatRoundStyle::round_half_ulp_truncate> convert_;
#endif
return convert_(from);
} else {
cutlass::NumericConverter<To_type, From_type> convert_;
return convert_(from);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
constexpr int numel = decltype(size(tensor))::value;
static_assert(numel % 2 == 0);
using value_t = typename Engine::value_type;
// HACK: this requires tensor to be "contiguous"
Tensor tensor_uint32 = recast<uint32_t>(tensor);
#pragma unroll
for (int i = 0; i < size(tensor_uint32); ++i) {
tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
static_assert(std::is_same_v<float, From_type>);
constexpr int numel = decltype(size(tensor))::value;
static_assert(numel % 2 == 0);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// HACK: this requires tensor to be "contiguous"
Tensor tensor_float2 = recast<float2>(tensor);
Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
#pragma unroll
for (int i = 0; i < size(out_uint32); ++i) {
out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
}
Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
#else
Tensor out = flash::convert_type<To_type>(tensor);
flash::relu_(out);
#endif
return out;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
// asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, _));
}
}
}
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_v(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// #pragma unroll
// for (int k = 0; k < size<2>(S); ++k) {
// if (Is_even_K || predicate_K(k)) {
// cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
// } else if (Clear_OOB_K) {
// cute::clear(D(_, m, k));
// }
// }
// }
// else if (Clear_OOB_MN) {
// cute::clear(D(_, m, _));
// }
if (Is_even_K || predicate_K(m)) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_MN || get<0>(identity_MN(0, 0, k)) < max_MN) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
else if (Clear_OOB_MN) {
cute::clear(D(_, m, _));
}
}
}
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_k_idx(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, int k_idx, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
if (Is_even_K || predicate_K(k_idx)) {
cute::copy(tiled_copy, S(_, m, k_idx), D(_, m, k_idx));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k_idx));
}
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, k_idx));
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K,
const int max_MN=0, const int min_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
cute::copy(S(_, m, k), D(_, m, k));
}
}
}
}
}
#if 1
/*
for _64x32, use thread layout is 64x4, per thread get 8 elements, get 64x32 data, put data in lds with 32x64
for _16x128, use thread layout is 16x16, per thread get 8 elements, get 16x128 data, put data in lds with 32x64
for _16x192, use thread layout is 16x16, per thread get 12 elements, get 16x192 data, put data in lds with 48x64
for _16x64_128, use thread layout is 16x16, per thread get 4 elements with offset 128, get 16x64 data, put data in lds with 16x64
*/
enum MMA_LAYOUT{ _64x32 /* for gemm0 load K */,_64x64_LIT, _64x16 /* for gemm1 load V */, _16x128 /* for gemm1 load V */, _16x192 /* for dim 192 */, _16x64_128 /* for dim 64 */, _16x64_64 /*for load dim 64 V*/ ,
_16x96 /*for load dim 96 V*/,
_16x96_multi_ins /*for load dim 96 V*/,
_16x256 /* for dim 256 read V */,
_64x64, _32x128 /* for dim 192,128 fp8 read KV */
};
template <bool Is_even_K=true,
bool Is_even_MN=true,
MMA_LAYOUT mma_layout = _64x32,
int K_BUFF_SIZE = 0,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_K = 0, const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size) % 4;
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
int k_slide = k_idx;
if constexpr(K_BUFF_SIZE) {
k_slide = (k_idx % K_BUFF_SIZE);
}
const int offset_s = 0;
// global addr
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
if constexpr (Use_cache_swizzle) {
glob_ptr.latter += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
}
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
if constexpr(mma_layout == _64x32) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row * 4 + warp_id;
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _64x16) {
constexpr int elements_per_thread = 4;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row * 4 + warp_id;
int col_offset = col * elements_per_thread + k_idx * 16;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && col_offset >= max_MN) offset_v = -1;
if (!Is_even_K && row_offset >= max_K) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x128) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*128;
int row = lane / 4;
int col = tidx % 4;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x192) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 48*64;
int row = lane / 4;
int col = tidx % 4;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
constexpr int elements_per_thread_tail = 4;
constexpr int bytes_per_warp_tail = warp_size * elements_per_thread_tail * element_size;
row = (tidx / 8) % 16;
col = tidx % 8;
row_offset = row + k_idx * 16;
col_offset = col * elements_per_thread_tail + warp_id / 2 * 32 + /* pre offset */128 ;
offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + /* pre offset */64*32 * element_size + warp_id * bytes_per_warp_tail + k_slide * mma_k * element_size;
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x64_128) {
constexpr int elements_per_thread = 4;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*64;
int row = (tidx / 8) % 16;
int col = tidx % 8;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id / 2 * 32 + /* pre offset */128 ;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x64_64) {
constexpr int elements_per_thread = 4;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*64;
int row = (tidx / 8) % 16;
int col = tidx % 8;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id / 2 * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
// if (tidx < 64) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x96) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*96;
int row = lane / 4;
int col = tidx % 4;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
if (warp_id < 3) {
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
} else if constexpr(mma_layout == _16x96_multi_ins) {
constexpr int elements_per_thread = 4;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*96;
int row = lane / 8;
int col = tidx % 8;
int row_offset = row + (warp_id % 2) * 8 + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id / 2 * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
constexpr int elements_per_thread_tail = 2;
constexpr int bytes_per_warp_tail = warp_size * elements_per_thread_tail * element_size;
row = lane / 16;
col = tidx % 16;
row_offset = row + warp_id * 4 + k_idx * 16;
col_offset = col * elements_per_thread_tail + /* pre offset */64 ;
offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + /* pre offset */16*64 * element_size + warp_id * bytes_per_warp_tail + k_slide * mma_k * element_size;
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
}
#define fp8 unsigned char
__forceinline__ __device__ float fp8e5m2_to_fp32(const fp8& input) {
union uf16{
uint16_t as_bits;
_Float16 as_value;
} ;
union uf32 {
uint32_t as_bits;
float as_value;
};
uf16 u16;
uf32 u32;
u16.as_bits = (uint16_t)input << 8;
u32.as_value = (float)u16.as_value;
return u32.as_value;
}
template <typename Element, bool Is_even_K=true,
bool Is_even_MN=true,
MMA_LAYOUT mma_layout = _64x32,
int K_BUFF_SIZE = 0,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_kv_fp8(float scale,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_K = 0, const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size) % 4;
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
int k_slide = k_idx;
if constexpr(K_BUFF_SIZE) {
k_slide = (k_idx % K_BUFF_SIZE);
}
if constexpr(mma_layout == _64x32) {
constexpr int elements_per_thread = 8;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row * 4 + warp_id;
int col_offset = col * elements_per_thread + k_idx * 32;
Element rst[8];
cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
const fp8* src_ptr = reinterpret_cast<const fp8*>(src.data().get());
#pragma unroll
for (int i = 0; i < 8; ++i) {
if ((Is_even_K || col_offset < max_K) &&
(Is_even_MN || row_offset < max_MN)) {
int offset = row_offset * row_stride + col_offset + i;
float f = fp8e5m2_to_fp32(src_ptr[offset]) * scale;
rst[i] = convert_(f);
} else {
rst[i] = Element(0);
}
}
int element_offset = warp_id * warp_size * elements_per_thread + k_slide * mma_k + lane * elements_per_thread;
Element* lds_ptr = dst.data().get() + element_offset;
*reinterpret_cast<uint4*>(lds_ptr) = *reinterpret_cast<uint4*>(rst);
} else if constexpr(mma_layout == _64x16) {
constexpr int elements_per_thread = 4;
int mma_k = 16*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row * 4 + warp_id;
int col_offset = col * elements_per_thread + k_idx * 16;
Element rst[4];
cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
const fp8* src_ptr = reinterpret_cast<const fp8*>(src.data().get());
bool valid = (Is_even_K || row_offset < max_K) &&
(Is_even_MN || col_offset < max_MN); // 不检查 col_offset+i
for (int i = 0; i < 4; ++i) {
if (valid) {
int offset = row_offset * row_stride + col_offset + i;
float f = fp8e5m2_to_fp32(src_ptr[offset]) * scale;
rst[i] = convert_(f);
} else {
rst[i] = Element(0);
}
}
int element_offset = warp_id * warp_size * elements_per_thread + k_slide * mma_k + lane * elements_per_thread;
Element* lds_ptr = dst.data().get() + element_offset;
*reinterpret_cast<uint2*>(lds_ptr) = *reinterpret_cast<uint2*>(rst);
}
}
template <bool Is_even_K=true,
bool Is_even_MN=true,
MMA_LAYOUT mma_layout = _64x32,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy(int k_slide,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_K = 0, const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size) % 4;
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// global addr
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
if constexpr (Use_cache_swizzle) {
glob_ptr.latter += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
}
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
if constexpr(mma_layout == _64x32) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row * 4 + warp_id;
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x128) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*128;
int row = lane / 4;
int col = tidx % 4;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x192) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 48*64;
int row = lane / 4;
int col = tidx % 4;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
constexpr int elements_per_thread_tail = 4;
constexpr int bytes_per_warp_tail = warp_size * elements_per_thread_tail * element_size;
row = (tidx / 8) % 16;
col = tidx % 8;
row_offset = row + k_idx * 16;
col_offset = col * elements_per_thread_tail + warp_id / 2 * 32 + /* pre offset */128 ;
offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + /* pre offset */64*32 * element_size + warp_id * bytes_per_warp_tail + k_slide * mma_k * element_size;
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x64_128) {
constexpr int elements_per_thread = 4;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*64;
int row = (tidx / 8) % 16;
int col = tidx % 8;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id / 2 * 32 + /* pre offset */128 ;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x64_64) {
constexpr int elements_per_thread = 4;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*64;
int row = (tidx / 8) % 16;
int col = tidx % 8;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id / 2 * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
// if (tidx < 64) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x96) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*96;
int row = lane / 4;
int col = tidx % 4;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
if (warp_id < 3) {
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
} else if constexpr(mma_layout == _16x96_multi_ins) {
constexpr int elements_per_thread = 4;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*96;
int row = lane / 8;
int col = tidx % 8;
int row_offset = row + (warp_id % 2) * 8 + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id / 2 * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
constexpr int elements_per_thread_tail = 2;
constexpr int bytes_per_warp_tail = warp_size * elements_per_thread_tail * element_size;
row = lane / 16;
col = tidx % 16;
row_offset = row + warp_id * 4 + k_idx * 16;
col_offset = col * elements_per_thread_tail + /* pre offset */64 ;
offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + /* pre offset */16*64 * element_size + warp_id * bytes_per_warp_tail + k_slide * mma_k * element_size;
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
}
template <bool Is_even_K=true,
bool Is_even_MN=true,
MMA_LAYOUT mma_layout = _16x256,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy(int n_idx, int k_slide,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_K = 0, const int max_MN=0)
{
constexpr int warp_size = 64;
const int tidx = threadIdx.x;
const int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size) % 4;
const int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// global addr
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
if constexpr (Use_cache_swizzle) {
glob_ptr.latter += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
}
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
if constexpr(mma_layout == _16x256) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*128;
int row = lane / 4;
int col = tidx % 4;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread + warp_id * 32 + n_idx * 128;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
}
template <bool Is_even_K=true,
bool Is_even_MN=true,
MMA_LAYOUT mma_layout = _64x64,
int K_BUFF_SIZE = 0,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_fp8(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_K = 0, const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
int k_slide = k_idx;
if constexpr(K_BUFF_SIZE) {
k_slide = (k_idx % K_BUFF_SIZE);
}
const int offset_s = 0;
// global addr
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
if constexpr (Use_cache_swizzle) {
glob_ptr.latter += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
}
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
if constexpr(mma_layout == _64x64) {
// constexpr int elements_per_thread = 16;
// constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
// int mma_k = 64*64;
// int row = tidx % 16;
// int col = lane / 16;
// int row_offset = row * 4 + warp_id;
// int col_offset = col * elements_per_thread + k_idx * 64;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_K && col_offset >= max_K) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 64*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row * 2 + warp_id +(warp_id/2)*30;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if defined(__gfx938__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}else if constexpr(mma_layout == _64x64_LIT) {
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 64*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row * 2 + warp_id +(warp_id/2)*30;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if defined(__gfx938__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}else if constexpr(mma_layout == _64x32) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row * 4 + warp_id;
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && col_offset >= max_MN) offset_v = -1;
if (!Is_even_K && row_offset >= max_K) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if defined(__gfx938__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _32x128) {
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 32*128;
int row = lane / 2;
int col = tidx % 2;
int row_offset = row + k_idx * 32;
int col_offset = col * elements_per_thread + warp_id * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_K && col_offset >= max_K) offset_v = -1;
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if defined(__gfx938__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
}
template <bool Is_even_K=true,
bool Is_even_MN=true,
MMA_LAYOUT mma_layout = _64x32,
int K_BUFF_SIZE = 0,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_for_vertical_sparse(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride, int row_offset,
const int max_K = 0, const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
int k_slide = k_idx;
if constexpr(K_BUFF_SIZE) {
k_slide = (k_idx % K_BUFF_SIZE);
}
const int offset_s = 0;
if constexpr(mma_layout == _64x32) {
// global addr
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
{
// 设置stride值为16, 因为一个线程读取8个元素, 16字节
glob_ptr.latter += ((row_stride * 2 ) << 16); // 62 bit: cache swizzle; 48~61: Stride
}
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = __builtin_amdgcn_readfirstlane(max_MN); // number records 95:64
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
// int row_offset = cols_ptr[row * 4 + warp_id];
int col_offset = col * elements_per_thread + k_idx * 32;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
uint32x2_t offset_v = {0};
offset_v[0] = row_offset;
offset_v[1] = col_offset * 2;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_K && col_offset >= max_K) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,idxen offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
} else if constexpr(mma_layout == _16x128) {
// global addr
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// if constexpr (Use_cache_swizzle)
{
// 设置stride值为16, 因为一个线程读取8个元素, 16字节
glob_ptr.latter += ((row_stride * 2 ) << 16); // 62 bit: cache swizzle; 48~61: Stride
}
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = __builtin_amdgcn_readfirstlane(max_MN); // number records 95:64
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*128;
int row = lane / 4;
int col = tidx % 4;
// int row_offset = cols_ptr[row + k_idx * 16];
int col_offset = col * elements_per_thread + warp_id * 32;
// int64_t offset_v = (row_offset + col_offset) / 8; // bytes
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
uint32x2_t offset_v = {0};
offset_v[0] = row_offset ;
offset_v[1] = col_offset * 2;
// if (!Is_even_K && col_offset >= max_K) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// int index_v = offset_v;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,idxen offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
}
template<MMA_LAYOUT mma_layout = _64x32, int divide = 1, typename Layout>
__forceinline__ __device__ auto convert_layout_B_rowcol(Layout B_layout) {
static_assert(decltype(rank(B_layout))::value == 3);
if constexpr(mma_layout == _64x32||mma_layout == _64x64) {
auto layout = make_layout(get<1>(B_layout), get<0>(B_layout), get<2>(B_layout));
auto l = logical_divide(layout, Shape<Int<divide>>{});
// if (thread0()) {
// printf("l:"); print(l); printf("\n");
// }
return make_layout(get<1>(l), get<1>(get<0>(l)), get<0>(get<0>(l)));
} else if constexpr(mma_layout == _16x128) {
return make_layout(get<0>(B_layout), get<2>(B_layout), get<1>(B_layout));
} else if constexpr(mma_layout == _16x192) {
// disgusting!!! hard code
auto layout = make_layout(Shape<Shape<_8, _1>, _6, _4>{}, Stride<Stride<_1, _0>, _512, Int<3072>>{}); // ((_8,_1),_6,_4):((_1,_0),_3072,_512)
// if (thread0()) {
// printf("layout:"); print(layout); printf("\n");
// }
return layout;
} else if constexpr(mma_layout == _16x64_128) {
// disgusting!!! hard code
return make_layout(Shape<Shape<_8, _1>, _2, _4>{}, Stride<Stride<_1, _0>, _512, Int<1024>>{});
} else if constexpr(mma_layout == _16x64_64) {
// disgusting!!! hard code
return make_layout(Shape<Shape<_8, _1>, _2, _4>{}, Stride<Stride<_1, _0>, _512, Int<1024>>{});
}
};
template<MMA_LAYOUT mma_layout = _64x64, int divide = 1, typename Layout>
__forceinline__ __device__ auto convert_layout_B_rowcol_fp8(Layout B_layout) {
static_assert(decltype(rank(B_layout))::value == 3);
if constexpr(mma_layout == _64x64) {
auto layout = make_layout(get<1>(B_layout), get<0>(B_layout), get<2>(B_layout));
auto l = logical_divide(layout, Shape<Int<divide>>{});
return make_layout(get<1>(l), get<1>(get<0>(l)), get<0>(get<0>(l)));
} else if constexpr(mma_layout == _32x128) {
auto layout = make_layout(Shape<Shape<_16, _1>, _4, _2>{}, Stride<Stride<_1, _0>, _1024, Int<4096>>{});
return layout;
}
};
template<MMA_LAYOUT mma_layout = _64x32, int divide = 1, typename Layout>
__forceinline__ __device__ auto convert_layout_B_rowcol_(Layout B_layout) {
static_assert(decltype(rank(B_layout))::value == 3);
if constexpr(mma_layout == _64x32) {
auto layout = make_layout(get<2>(B_layout), get<0>(B_layout), get<1>(B_layout));
auto l = logical_divide(layout, Shape<Int<divide>>{});
return make_layout(get<1>(l), get<0>(get<0>(l)), get<1>(get<0>(l)));
} else if constexpr(mma_layout == _16x128 || mma_layout == _16x192 || mma_layout == _16x64_64 || mma_layout == _16x96) {
auto layout = make_layout(get<1>(B_layout), get<0>(B_layout), get<2>(B_layout));
auto l = logical_divide(layout, Shape<Int<divide>>{});
return make_layout(get<1>(l), get<0>(get<0>(l)), get<1>(get<0>(l)));
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// resolves offset of a slice of a paged kv copy from gmem.
// assumes that the tensor has already been positioned at the correct head.
template <typename Kernel_traits>
__forceinline__ __device__
int64_t resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN;
const int64_t col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
const int64_t block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
const int64_t global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
const int64_t page_offset = global_row_offset % page_block_size;
const int64_t virtual_page_idx = global_row_offset / page_block_size;
return ((int64_t) block_table[virtual_page_idx]) * ((int64_t) page_stride)
+ page_offset * ((int64_t) row_stride)
+ col_offset;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k),
// where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures
// that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors.
template <class Shape, class Stride>
__forceinline__ __device__
auto reshape_thread_tile(Layout<Shape, Stride> l) {
return make_layout(append(get<0>(l.shape()), get<2>(l.shape())),
append(get<0>(l.stride()), get<2>(l.stride())));
}
// reshapes and flattens the thread tile layout. A separate function is needed for the case where
// one of the modes of l is a layout itself and must be flattened, as opposed to keeping it intact
// for the case of swizzled layouts
template <class Shape, class Stride>
__forceinline__ __device__
auto reshape_flatten_thread_tile(Layout<Shape, Stride> l) {
auto mode_0 = filter(flatten(get<0>(l)));
return make_layout(append(mode_0.shape(), get<2>(l.shape())),
append(mode_0.stride(), get<2>(l.stride())));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(src_tensor); ++i) {
dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Use_mask, typename TensorSrc, typename TensorAcc>
__forceinline__ __device__ void apply_atten_mask(TensorSrc const& src, TensorAcc& accum, float value = -INFINITY) {
if constexpr(Use_mask) {
CUTE_STATIC_ASSERT_V(size(src) == size(accum));
#pragma unroll
for (int i = 0; i < size(src); i++) {
accum(i) = src(i) ? accum(i) : value;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/*
原来的 exp2f 对于极小数有特殊处理, 对于小于 -126 的输入 x , exp2f 计算方式是 2^(x + 64) * 2^{-64}
但是对于深度学习来说, 2^-126 的数字其实没那么重要了, 因此只需要保留 v_exp_f32 直接暴力计算即可
*/
extern __device__ __attribute__((const)) float __llvm_exp2_f32(float) __asm("llvm.exp2.f32");
__forceinline__ __device__ float custom_exp2f(float x) {
#if 0
return __exp2f(x);
#elif 0
return __llvm_exp2_f32(x);
#else
return __builtin_amdgcn_exp2f(x);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
This source diff could not be displayed because it is too large. You can view the blob instead.
#include <fstream>
#include <cstdlib>
#include <vector>
#include "numeric_types.h"
#include "config.h"
#include "static_switch.h"
#include "flash_singleton.h"
#include "flash_memory.h"
// ====================================================================================================================================
// FWD
// ====================================================================================================================================
void run_mha_fwd(Flash_fwd_params &params, hipStream_t stream, bool force_split_kernel=false) {
#if defined(BUILD_FA_FWD)
const char* fa_debug = std::getenv("FA_DEBUG");
if (fa_debug != nullptr) {
if (std::strcmp(fa_debug, "5") == 0) return;
else if (std::strcmp(fa_debug, "C") == 0) {PRINT_PARAMS}; // for c interface debug
}
if (params.seqused_k != nullptr) {
// Prefix prefill attention
if (!params.is_int8){
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d == 128 and params.d_value == 128) {
run_mha_fwd_prefix_prefill_<elem_type, 128, 128>(params, stream);
} else if (params.d == 192 and params.d_value == 128) {
run_mha_fwd_prefix_prefill_<elem_type, 192, 128>(params, stream);
} else if (params.d == 192 and params.d_value == 192) {
run_mha_fwd_prefix_prefill_<elem_type, 192, 192>(params, stream);
} else if (params.d == 256 and params.d_value == 256) {
run_mha_fwd_prefix_prefill_<elem_type, 256, 256>(params, stream); // used in gemma2-9b
}
});
} else {
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d == 128 and params.d_value == 128) {
run_int8_mha_fwd_prefix_prefill_<elem_type, 128, 128>(params, stream);
} else if (params.d == 192 and params.d_value == 128) {
run_int8_mha_fwd_prefix_prefill_<elem_type, 192, 128>(params, stream);
} else if (params.d == 192 and params.d_value == 192) {
run_int8_mha_fwd_prefix_prefill_<elem_type, 192, 192>(params, stream);
}
});
}
}
else if (params.attn_mask != nullptr) {
// Broadcastable mask attention like torch.sdpa
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d == 128) {
run_mha_fwd_attn_mask_<elem_type, 128, 128>(params, stream);
}
});
}
else if (params.padding_mask != nullptr) {
// Encoder attention, bert. etc.
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d == 64) {
run_mha_fwd_padding_mask_<elem_type, 64, 64>(params, stream);
} else if (params.d == 128) {
run_mha_fwd_padding_mask_<elem_type, 128, 128>(params, stream);
}
});
}
else {
// Decoder-only attention
FP16_SWITCH(!params.is_bf16, [&] {
#if defined(HEADDIM_128_ONLY)
run_mha_fwd_<elem_type, 128, 128>(params, stream);
#elif defined(HEADDIM_192_128_ONLY)
run_mha_fwd_<elem_type, 192, 128>(params, stream);
#else
ALL_HEADDIM_SWITCH(params.d, params.d_value, [&] {
run_mha_fwd_<elem_type, kHeadDimQ, kHeadDimV>(params, stream);
});
#endif
});
}
#endif
}
void (*run_mha_fwd_c)(Flash_fwd_params&, hipStream_t, bool) = run_mha_fwd;
// ====================================================================================================================================
// BWD
// ====================================================================================================================================
void run_mha_bwd(Flash_bwd_params &params, hipStream_t stream, const bool configure=false) {
#if defined(BUILD_FA_BWD)
const char* fa_debug = std::getenv("FA_DEBUG");
if (fa_debug != nullptr) {
if (std::strcmp(fa_debug, "5") == 0) return;
else { printFlashBwdParams(params); };
}
ElementType_SWITCH(params.is_bf16, params.is_e4m3, [&] {
#if defined(HEADDIM_128_ONLY)
run_mha_bwd_<elem_type, 128, 128>(params, stream, configure);
#elif defined(HEADDIM_192_128_ONLY)
run_mha_bwd_<elem_type, 192, 128>(params, stream, configure);
#else
HEADDIM_SWITCH(params.d, params.d_value, [&] {
run_mha_bwd_<elem_type, kHeadDimQ, kHeadDimV>(params, stream, configure);
});
#endif
});
#endif
}
// ====================================================================================================================================
// PA
// ====================================================================================================================================
void run_mha_fwd_kvcache(Flash_fwd_params &params, hipStream_t stream, bool force_split_kernel=false) {
#if defined(BUILD_FA_KVCACHE)
const char* fa_debug = std::getenv("FA_DEBUG");
if (fa_debug != nullptr) {
if (std::strcmp(fa_debug, "5") == 0) return;
else if (std::strcmp(fa_debug, "C") == 0) {PRINT_PARAMS}; // for c interface debug
}
FP16_SWITCH(!params.is_bf16, [&] {
#ifdef HEADDIM_128_ONLY
run_mha_fwd_splitkv_dispatch<elem_type, 128, 128>(params, stream);
#elif defined(HEADDIM_192_128_ONLY)
if (params.d == 192 and params.d_value == 128)
run_mha_fwd_splitkv_dispatch<elem_type, 192, 128>(params, stream);
else if (params.d == 576 and params.d_value == 512)
run_mha_fwd_splitkv_dispatch<elem_type, 576, 512>(params, stream);
#else
PA_HEADDIM_SWITCH(params.d, params.d_value, [&] {
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDimQ, kHeadDimV>(params, stream);
});
#endif
});
#endif
}
void run_int8_fwd_kvcache(Flash_fwd_params &params, hipStream_t stream, bool force_split_kernel=false) {
#if defined(BUILD_FA_KVCACHE)
const char* fa_debug = std::getenv("FA_DEBUG");
if (fa_debug != nullptr) {
if (std::strcmp(fa_debug, "5") == 0) return;
else if (std::strcmp(fa_debug, "C") == 0) {PRINT_PARAMS}; // for c interface debug
}
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d != 128 or params.d_value != 128){
printf("int8 pa only support headdim=128!\n");
assert(params.d == 128 and params.d_value == 128);
}
run_int8_fwd_splitkv_dispatch<elem_type, 128, 128>(params, stream);
});
#endif
}
// ====================================================================================================================================
// FlashMLA
// ====================================================================================================================================
void run_fwd_flashmla(Flash_fwd_mla_params &params, hipStream_t stream, bool force_split_kernel=false) {
#if defined(BUILD_FLASHMLA)
const char* fa_debug = std::getenv("FA_DEBUG");
if (fa_debug != nullptr) {
if (std::strcmp(fa_debug, "5") == 0) return;
else if (std::strcmp(fa_debug, "C") == 0) {PRINT_MLA_PARAMS}; // for c interface debug
}
FP16_SWITCH(!params.is_bf16, [&] {
run_mla_fwd_splitkv_dispatch<elem_type, 576, 512>(params, stream);
});
#endif
}
void run_fwd_prefix_prefill_mla(Flash_fwd_mla_params &params, hipStream_t stream) {
#if defined(BUILD_FA_FWD)
const char* fa_debug = std::getenv("FA_DEBUG");
if (fa_debug != nullptr) {
if (std::strcmp(fa_debug, "5") == 0) return;
else if (std::strcmp(fa_debug, "C") == 0) {PRINT_MLA_PARAMS}; // for c interface debug
}
FP16_SWITCH(!params.is_bf16, [&] {
run_mla_fwd_prefix_prefill_dispatch_<elem_type, 576, 512>(params, stream);
});
#endif
}
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen=true, bool Is_Kvcache=false, bool USE_BSHD_LAYOUT = false>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q((!Varlen || params.cu_seqlens_q == nullptr) ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k((!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative) ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr || Is_Kvcache ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache/* + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)*/)
, nheads(params.h)
, nheads_k(params.h_k)
, leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
{
}
template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
}
inline __device__ int q_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_q) * row_stride : uint32_t(sum_s_q) * row_stride * nheads);
}
inline __device__ int k_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_k) * row_stride : uint32_t(sum_s_k) * row_stride * nheads_k);
}
inline __device__ int k_offset1_write(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_k) * row_stride : uint32_t(sum_s_k) * row_stride * nheads);
}
inline __device__ int q_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_q == -1) ? bidh * head_stride : uint32_t(actual_seqlen_q) * head_stride * bidh;
}
inline __device__ int k_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_k == -1) ? bidh * head_stride : uint32_t(actual_seqlen_k) * head_stride *bidh;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int leftpad_k;
const int seqlen_k_cache;
int actual_seqlen_k;
const int nheads;
const int nheads_k;
};
// Simplified blockinfo for tranditional varlen fwd inference
template<bool USE_BSHD_LAYOUT=false>
struct SimplifyBlockInfo {
template<typename Params>
__device__ SimplifyBlockInfo(const Params &params, const int bidb)
: sum_s_q(params.cu_seqlens_q[bidb])
, sum_s_k(params.cu_seqlens_k[bidb])
, actual_seqlen_q(params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache((params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache/* + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)*/)
, nheads(params.h)
, nheads_k(params.h_k)
// , leftpad_k(0)
{
}
inline __device__ int q_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_q) * row_stride : uint32_t(sum_s_q) * row_stride * nheads);
}
inline __device__ int k_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_k) * row_stride : uint32_t(sum_s_k) * row_stride * nheads_k);
}
inline __device__ int q_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_q == -1) ? bidh * head_stride : uint32_t(actual_seqlen_q) * head_stride * bidh;
}
inline __device__ int k_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_k == -1) ? bidh * head_stride : uint32_t(actual_seqlen_k) * head_stride *bidh;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
// const int leftpad_k;
const int seqlen_k_cache;
int actual_seqlen_k;
const int nheads;
const int nheads_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SafeDecodeBlockInfo {
__device__ SafeDecodeBlockInfo() = default;
template<typename Params, bool Is_Q_varlen, bool Is_K_Cumulative>
__device__ void set_params(const Params &params, const int bidb) {
// process Q
if constexpr (Is_Q_varlen) { // Is_Q_varlen also means Is_Q_Cumulative = true
this->sum_s_q = params.cu_seqlens_q[bidb];
this->actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - this->sum_s_q;
} else {
this->actual_seqlen_q = params.seqlen_q;
}
// process KV
if constexpr (Is_K_Cumulative) {
this->sum_s_k = params.cu_seqlens_k[bidb];
this->actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k;
} else {
this->actual_seqlen_k = params.cu_seqlens_k[bidb];
}
}
int sum_s_q;
int sum_s_k;
int actual_seqlen_q;
int actual_seqlen_k;
};
} // namespace flash
#pragma once
#include <block_info.h>
#include "utils.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccumType, int kBlockM_, int kBlockN_, int WARP_M_, int WARP_N_, int kHeadDim_, int STAGES, bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
Element *do_ptr = static_cast<Element*>(params.do_ptr);
Element *o_ptr = static_cast<Element*>(params.o_ptr);
ElementAccumType* dsoftmax_sum = static_cast<ElementAccumType*>(params.dsoftmax_sum);
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id =0;
__shared__ Element do_lds[STAGES*(kBlockM_/32) * (kBlockN_/32)*(32*34)];
__shared__ Element o_lds[STAGES*(kBlockM_/32) * (kBlockN_/32)*(32*34)];
float dP_sum_cur[(kBlockM_/16)] = {0.0f};
int stage_id = 0;
constexpr int kBlockM = kBlockM_;
constexpr int kBlockN = kBlockN_;
constexpr int kHeadDim = kHeadDim_;
const int WARP_NUM = (kBlockM_)/(WARP_M_);
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM * seqlen_o_stride;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
// Element *gdO = reinterpret_cast<Element *>(do_ptr) + row_offset_do;
auto gdO = tcp_cache_swizzle_func<kHeadDim_, Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_do);
// Element *gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
auto gO = tcp_cache_swizzle_func<kHeadDim_, Element>(reinterpret_cast<Element *>(o_ptr) + row_offset_o);
ElementAccumType *dP_sum = reinterpret_cast<ElementAccumType *>(dsoftmax_sum) + row_offset_dpsum;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_id)
: "v"(warp_id_vec)
:);
vec2_Element<Element> do_reg[(kHeadDim_/kBlockN_)*((WARP_M_*kBlockN_)/(32*32))*2][4]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
vec2_Element<Element> o_reg[(kHeadDim_/kBlockN_)*((WARP_M_*kBlockN_)/(32*32))*2][4]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
// int A_lane_m_idx = (lane_id >> 4);
int do_lane_m_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
int do_lane_head_dim_idx = (lane_id & 15);
//global->lds, left matrix
// printf("kBlockN_==%d, kHeadDim_=%d, WARP_M_=%d\n",kBlockN_, kHeadDim_, WARP_M_);
for(int k_loop=0; k_loop<kHeadDim_/kBlockN_; k_loop++) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int do_block_buffer_load_global_offset = k_loop * kBlockN_;
// do_ptr buffer load mini size is 4*32, (kBlockM_ * kBlockN_) mini size is (32*32)
const int do_lds_load_num = (kBlockM_ * kBlockN_) / (4*32);
int do_lds_stage_offset = stage_id * (kBlockM_/32) * (kBlockN_/32)*(32*34);
for(int warp_loop=warp_id; warp_loop < do_lds_load_num; warp_loop+=WARP_NUM) {
int padding = (warp_loop & 7)*2; // padding size in shared memory per buffer load, to avoid bank conflict
int do_warp_buffer_load_m_id = (warp_loop & (kBlockM_/4 - 1)); //这样子对L1和utlc1有啥影响呢?
int do_warp_buffer_load_k_id = (warp_loop / (kBlockM_/4));
int do_warp_buffer_load_lds_offset = do_lds_stage_offset + (do_warp_buffer_load_k_id * kBlockM_ * 34) + ((do_warp_buffer_load_m_id >> 3)*(32*34) + (do_warp_buffer_load_m_id & 7)*(4*32)) ;
int do_warp_buffer_load_global_offset = (do_warp_buffer_load_k_id * 32);
int gsOffset = (do_block_buffer_load_global_offset + do_warp_buffer_load_global_offset)/2 ;
// int gvOffset = (do_lane_m_idx * kHeadDim_)/2 + do_lane_head_dim_idx;
int lds_offset = (do_warp_buffer_load_lds_offset + padding)/2;
{
int gvOffset;
if constexpr (!Is_even_MN) {
gvOffset = (min((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)),binfo.actual_seqlen_q - m_block * kBlockM - 1) * seqlen_do_stride)/2 + do_lane_head_dim_idx;
} else {
gvOffset = ((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)) * seqlen_do_stride)/2 + do_lane_head_dim_idx;
}
builtin_buffer_load_dword_lds(do_lds, gdO, lds_offset, gsOffset, gvOffset);
}
{
int gvOffset;
if constexpr (!Is_even_MN) {
gvOffset = (min((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)),binfo.actual_seqlen_q - m_block * kBlockM - 1) * seqlen_o_stride)/2 + do_lane_head_dim_idx;
} else {
gvOffset = ((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)) * seqlen_o_stride)/2 + do_lane_head_dim_idx;
}
builtin_buffer_load_dword_lds(o_lds, gO, lds_offset, gsOffset, gvOffset);
}
}
vmcnt_wait(0);
// By right we need to scale dP up by 1/params.p_dropout, but instead we don't and only scale the final
// results (dQ and dK) by 1/params.p_dropout. So we need to keep dP_sum scaled down by params.p_dropout here,
// so that (dP - dP_sum) is on the same scale.
{
//lds -> vgpr use ds_read_m; left matrix
int do_warp_m_id = (warp_id & ((kBlockM_/WARP_M_) - 1));
int do_lds_stage_offset = stage_id * (kBlockM_/32) * (kBlockN_/32)*(32*17);
vec2_Element<Element> *do_lds_v2fp16 = (vec2_Element<Element> *)(do_lds);
vec2_Element<Element> *o_lds_v2fp16 = (vec2_Element<Element> *)(o_lds);
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockN_/32); head_dim_idx++) { //32 half in col direction
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) {
//a warp load min size is (row, col) = (32,16) float
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { //sequence direction
#pragma unroll
for(int vec_id=0; vec_id<4; vec_id++) { //head dim direction
int lds_offset = do_lds_stage_offset + head_dim_idx*kBlockM_*17 + (warp_id*(WARP_M_/32) + m_idx)*(32*17) + vec_id*2 + min_tile_m*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
inline_ds_read_b32_wait(do_lds_v2fp16, lds_offset, do_reg[/*(k_loop)*((WARP_M_*kBlockN_)/(32*32))*2 +*/ (head_dim_idx*(WARP_M_/32) + m_idx)*2 + min_tile_m][vec_id]);
}
}
}
}
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockN_/32); head_dim_idx++) { //32 half in col direction
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) {
//a warp load min size is (row, col) = (32,16) float
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { //sequence direction
#pragma unroll
for(int vec_id=0; vec_id<4; vec_id++) { //head dim direction
int lds_offset = do_lds_stage_offset + head_dim_idx*kBlockM_*17 + (warp_id*(WARP_M_/32) + m_idx)*(32*17) + vec_id*2 + min_tile_m*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
inline_ds_read_b32_wait(o_lds_v2fp16, lds_offset, o_reg[/*(k_loop)*((WARP_M_*kBlockN_)/(32*32))*2 +*/ (head_dim_idx*(WARP_M_/32) + m_idx)*2 + min_tile_m][vec_id]);
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < (kBlockN/32); ++head_dim_idx) {
#pragma unroll
for(int vec_id = 0; vec_id<4; vec_id++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
if (Is_even_MN || (m_block * kBlockM + mi*32 + min_tile_m + (threadIdx.x & 15)*2) < binfo.actual_seqlen_q) {
dP_sum_cur[mi*2 + min_tile_m] += UpCast<Element,float,true>(do_reg[(head_dim_idx*(WARP_M_/32) + mi)*2 + min_tile_m][vec_id][min_tile_n]) * UpCast<Element,float,true>(o_reg[(head_dim_idx*(WARP_M_/32) + mi)*2 + min_tile_m][vec_id][min_tile_n]);
}
}
}
}
}
}
}
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
flash::SumOp<float> sum_op;
dP_sum_cur[mi*2 + min_tile_m] = flash::Allreduce<64>::run(dP_sum_cur[mi*2 + min_tile_m], sum_op) * params.p_dropout;
if ((threadIdx.x >> 4) == 0) {
dP_sum[mi*32 + min_tile_m + (threadIdx.x & 15)*2] = dP_sum_cur[mi*2 + min_tile_m];
}
}
}
}
\ No newline at end of file
#pragma once
#include <block_info.h>
#include "utils.h"
#include "prefetch.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccum, int kBlockM, int kBlockN, int WARP_M, int WARP_N, int K, int STAGES, bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o_gfx938(const Params &params) {
Element *do_ptr = static_cast<Element*>(params.do_ptr);
Element *o_ptr = static_cast<Element*>(params.o_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = 0;
__shared__ Element dO_lds[kBlockM * kBlockN];
__shared__ Element O_lds[kBlockM * kBlockN];
float dP_sum_cur[(kBlockM/16)] = {0.0f};
const int WARP_NUM = (kBlockM)/(WARP_M);
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM * seqlen_o_stride;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_do, seqlen_do_stride);
auto gO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(o_ptr) + row_offset_o, seqlen_o_stride);
ElementAccum *dP_sum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> dO_reg[((WARP_M*kBlockN)/(32*32))*2];
union_vec4_f16x2<Element> O_reg[((WARP_M*kBlockN)/(32*32))*2];
for(int k_loop=0; k_loop<K/kBlockN; k_loop++) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int do_block_buffer_load_global_offset = k_loop * kBlockN;
//read 32 * 128
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gdO, do_block_buffer_load_global_offset, dO_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gO, do_block_buffer_load_global_offset, O_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for(int i = 0; i < kBlockN / 32; ++i) {
DS_READ_MATRIX_32X32_B16(ds_offset_cast(dO_lds + i * 32 * 32), dO_reg[i * 2 + 0].f16, dO_reg[i * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(ds_offset_cast(O_lds + i * 32 * 32), O_reg[i * 2 + 0].f16, O_reg[i * 2 + 1].f16, true);
// if constexpr (std::is_same_v<Element, half_t>) {
// dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
// dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
// O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 0, 2, 1, 0);
// O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
// } else {
// dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
// dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
// O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 0, 2, 1, 0);
// O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
// }
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < (kBlockN/32); ++head_dim_idx) {
#pragma unroll
for(int vec_id = 0; vec_id<4; vec_id++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
if (Is_even_MN || (m_block * kBlockM + min_tile_m*16 + (threadIdx.x & 15)) < binfo.actual_seqlen_q) {
dP_sum_cur[min_tile_m] += UpCast<Element,float,false>(dO_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]) * UpCast<Element,float,false>(O_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]);
}
}
}
}
}
}
#pragma unroll
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
flash::SumOp<float> sum_op;
dP_sum_cur[mi*2 + min_tile_m] = flash::Allreduce<64>::run(dP_sum_cur[mi*2 + min_tile_m], sum_op) * params.p_dropout;
if ((threadIdx.x >> 4) == 0) {
dP_sum[mi*32 + min_tile_m * 16 + (threadIdx.x & 15)] = dP_sum_cur[mi*2 + min_tile_m];
}
}
}
}
#include <iostream>
#include <memory>
#include <vector>
#include <random>
#include <fstream>
#include <stdlib.h>
#include <dirent.h>
#include <unistd.h>
#include <sys/stat.h>
#include "assert.h"
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "flash.h"
#include "utils.h"
#include "wait.h"
#include "../numeric_types.h"
#include "philox.cuh"
#include "softmax_tiling.h"
#include "gpu_gemm_nn.h"
#include "gpu_gemm_tt.h"
#include "intrinsic.h"
#include "intrinsic_mls_ds.h"
#include "static_switch.h"
#include "dot_do_o.h"
#include "dot_do_o_gfx938.h"
#include "prefetch.h"
#include "flash_singleton.h"
#include "flash_attention_dv_dk_bwd.h"
#include "flash_attention_dv_dk_bwd_gfx938.h"
#include "flash_attention_dq_bwd.h"
#include "flash_attention_dq_bwd_gfx938.h"
using std::make_shared;
using std::shared_ptr;
template <int kBlockM_, int kBlockN_, int WARP_M_, int WARP_N_, typename Element>
inline __device__ void reshape(Element* smem, vec4_Element<Element> ds_reg_fp16[(WARP_N_/32)*(WARP_M_/32)][4], int warp_id) {
int lane_id = threadIdx.x & 63; //lane id, 0-63
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
int lds_offset = warp_id*(WARP_N_/32)*33*kBlockM_ + n_idx*33*kBlockM_ + m_idx*32*33 + min_tile_m*16*33 + vec_idx*4*33 + (lane_id>>4)*33 + min_tile_n*16 + (lane_id&15);
Element ds_reg_tmp = ds_reg_fp16[(WARP_N_/32)*m_idx + n_idx][min_tile_m*2 + min_tile_n][vec_idx];
{
smem[lds_offset] = ds_reg_fp16[(WARP_N_/32)*m_idx + n_idx][min_tile_m*2 + min_tile_n][vec_idx];
}
}
}
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(kBlockM_/32); m_idx++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
int lds_offset = warp_id*33*kBlockM_ + m_idx*32*33 + min_tile_m*16 + vec_idx*4 + (lane_id>>4) + min_tile_n*16*33 + (lane_id&15)*33;
ds_reg_fp16[(WARP_N_/32)*m_idx + n_idx][min_tile_m*2 + min_tile_n][vec_idx] = smem[lds_offset];
}
}
}
}
}
}
/*
* q_ptr: Transposed 32x16 matrix
* k_ptr: Non-transposed 32x16 matrix
* qk_ptr: Non-transposed 32x32 matrixseqlen_q
*/
template<class DataType>
int check_param(int seqlen_q, int seqlen_k, int K, int kBlockM_, int kBlockN_, int kBlockK_, int WARP_M_, int WARP_N_, dim3 blockDim, dim3 gridDim, int maxBlockThreads, int STAGES) {
// min warp size is 32x32
if(WARP_M_<32 || WARP_N_<32) {
std::cout<<"Error, WARP_M_<32 or WARP_N_<32!"<<std::endl;
assert(((WARP_M_>=32) && (WARP_N_>=32)));
}
// check block threads number
const int blockThreads = ((kBlockM_*kBlockN_)/(WARP_M_*WARP_N_)*64);
if(blockThreads > maxBlockThreads) {
std::cout<<"Error,Block threads is greater than maxBlockThreads! "<<std::endl;
assert(blockThreads <= maxBlockThreads);
}
//check lds data numbers
int DataTypeSize = sizeof(DataType);
const int q_lds_size = STAGES * kBlockM_ * kBlockK_ * DataTypeSize;
const int k_lds_size = STAGES * kBlockN_ * kBlockK_ * DataTypeSize;
if(((q_lds_size + k_lds_size)/1024) > 64) {
std::cout<<"Error, shared memory size is greater than 64KB"<<std::endl;
assert(((q_lds_size + k_lds_size)/1024) <= 64); //BW lds 64KB
}
}
#ifdef DEBUGING
#define print_qk(block_id_m, bidb, bidh) {\
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
kq_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
s_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_dp(block_id_m, bidb, bidh) {\
int dp_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int dp_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = dp_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
dp_ptr[offset] = dp_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_ds(block_id_m, bidb, bidh) {\
int ds_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int ds_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = ds_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
ds_ptr[offset] = dS_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#endif
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, int STAGES, int USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dq_1colblock(Params &params, int bidb, int bidh, int m_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
#if 1//defined(__gfx936__)
const bool Is_store_K = true;
const bool Is_preload_K = true;
const bool Is_preload_V = true;
#else
const bool Is_store_K = false;
const bool Is_preload_K = false;
const bool Is_preload_V = false;
#endif
const int K_prefetch_level = Is_preload_K ? 1 : 0;
const int V_prefetch_level = Is_preload_V ? 1 : 0;
const int Q_prefetch_level = 3;
Element* K_lds = (Element*)&(smem);
Element* Q_lds = (Element*)&(smem);
Element* dO_lds = (Element*)&(smem);
Element* V_lds = (Element*)&(smem) + (kBlockN_/32)*(K/32)*(32*34);//(Is_preload_K || Is_store_K) ? (Element*)&(smem) + (kBlockN_/32)*(K/32)*(32*34) : (Element*)&(smem);
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block < 0 || m_block * kBlockM_ >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM_ - params.window_size_left) / kBlockN_);
const int n_block_max = (!Is_causal && !Is_local) ? ceil_div(binfo.actual_seqlen_k, kBlockN_) : std::min(ceil_div(binfo.actual_seqlen_k, kBlockN_), flash::ceil_div((m_block + 1) * kBlockM_ + params.window_size_right, kBlockN_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dq_stride = params.dq_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + m_block * kBlockM_ * seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM_ * seqlen_o_stride;
const int row_offset_dq = binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh) + m_block * kBlockM_ * seqlen_dq_stride;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + m_block * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM_;
// Element * gQ = reinterpret_cast<Element *>(q_ptr) + row_offset_q;
auto gQ = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q);
// Element * gK = reinterpret_cast<Element *>(k_ptr) + row_offset_k;
auto gK = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k);
// Element * gV = reinterpret_cast<Element *>(v_ptr) + row_offset_v;
auto gV = tcp_cache_swizzle_func<K_v, Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v);
// Element * gdO = reinterpret_cast<Element *>(do_ptr) + row_offset_dO;
auto gdO = tcp_cache_swizzle_func<K_v, Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
dq_ptr = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
auto gdQ = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(dq_ptr));
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? flash::ceil_div(kBlockM_, kBlockN_) : flash::ceil_div(kBlockM_, kBlockN_) + 1);
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec2_f16x2<Element> q_reg[(K/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2][2];
union_vec2_f16x2<Element> dO_reg[(K_v/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2][2];
union_vec4_fp32 acc_dq[(K/kBlockK_) * ((WARP_M_/32)*(kBlockK_/32))][4]={0};
float lse[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int lse_idx = warp_id*WARP_M_ + mi*32 + ((lane_id & 15)*2) + min_tile_m;
lse[mi*2 + min_tile_m] = (Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_) ? gLSE[lse_idx] : INFINITY;
}
}
float dP_sum_reg[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int dP_sum_idx = warp_id*WARP_M_ + mi*32 + ((lane_id & 15)*2) + min_tile_m;
dP_sum_reg[mi*2 + min_tile_m] = gdPsum[dP_sum_idx];
}
}
prefetch_to_vgpr<K, kBlockM_, kBlockK_, WARP_N_, Element, ElementAccum, Is_even_MN>(gQ, Q_lds, q_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), seqlen_q_stride);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr<K_v, kBlockM_, kBlockK_, WARP_N_, Element, ElementAccum, Is_even_MN>(gdO, dO_lds, dO_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), seqlen_do_stride);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (Is_preload_K){
prefetch_to_tmp_lds_wait<Is_even_MN, K, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gK, K_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id, seqlen_k_stride);
}
if constexpr (Is_preload_V){
prefetch_to_tmp_lds_wait<Is_even_MN, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gV, V_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id, seqlen_v_stride);
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
for (int n_block = n_block_max - 1; n_block >= n_block_min ; --n_block) {
union_vec2_f16x2<Element> v_reg[((WARP_N_*kBlockK_)/(32*32))*2][2];
union_vec4_fp32 dp_reg[(WARP_M_/32)*(kBlockN_/32)][4]= {0};
{
//dp gemm
gemm_tt_kq<false, Is_preload_K, Is_even_MN, 3, V_prefetch_level, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(gdO, gV, dO_lds, V_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), dO_reg, v_reg, dp_reg, warp_id, seqlen_do_stride, seqlen_v_stride);
}
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec2_f16x2<Element> k_reg[((WARP_M_*kBlockK_)/(32*32))*2][2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq<Is_store_K, false, Is_even_MN, Q_prefetch_level, K_prefetch_level, K, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(gQ, gK, Q_lds, K_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), q_reg, k_reg, s_reg, warp_id, seqlen_q_stride, seqlen_k_stride);
*(uint64_t*)&gV -= ((kBlockN_ * seqlen_v_stride) * sizeof(Element));
if (Is_preload_V && n_block > n_block_min){
prefetch_to_tmp_lds_wait<Is_even_MN, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gV, V_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id, seqlen_v_stride);
}
apply_mask_bwd<Is_even_MN, Is_local ? 3 : (Is_causal ? 1 : 0)>(s_reg, binfo.actual_seqlen_q - m_block * kBlockM_ - warp_id * 32, binfo.actual_seqlen_k - n_block * kBlockN_, (m_block * kBlockM_ + warp_id * 32) - (n_block * kBlockN_), params.window_size_left, params.window_size_right);
#ifdef DEBUGING
print_qk(m_block, bidb, bidh);
#endif
scale_apply_exp2_bwd_seq_q_major</*scale_max=*/false, WARP_M_, kBlockN_, union_vec4_fp32, ElementAccum>(s_reg, lse, params.scale_softmax_log2);
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh)
#endif
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
// return p * (dp - d);
};
union_vec4_fp32 dS_reg[(WARP_M_/32)*(kBlockN_/32)][4];
#pragma unroll
for (int ni = 0; ni < (kBlockN_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = pointwise_mult(
s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dp_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dP_sum_reg[min_tile_m + mi*2]);
}
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec2_f16x2<Element> dS_reg_fp16[(WARP_M_/32)*(kBlockN_/32)][4];
convert_pk_type<WARP_M_, kBlockN_, Element>(dS_reg_fp16, dS_reg);
{
//dq gemm, K*dS
gpu_gemm_B_in_reg<Is_store_K , false , false, Is_even_MN, K, kBlockK_, kBlockM_, kBlockN_, kBlockK_, WARP_M_, 2, Element>(gK, gK, K_lds, dS_reg_fp16, acc_dq, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_k_stride);
}
*(uint64_t*)&gK -= ((kBlockN_ * seqlen_k_stride) * sizeof(Element));
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0){
// printf("(binfo.actual_seqlen_k - n_block * kBlockN_) = %d\n", (binfo.actual_seqlen_k - n_block * kBlockN_));
// }
#if 1//defined(__gfx936__)
{
__syncthreads();
if (Is_preload_K && n_block > n_block_min){
prefetch_to_tmp_lds_wait<Is_even_MN, K, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gK, K_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id, seqlen_k_stride);
}
}
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
{
int dq_lane_seq_idx = (lane_id >> 4);
int dq_lane_head_dim_idx = (lane_id & 15);
int dq_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
int dq_block_buffer_store_global_offset = k_loop * kBlockK_;
#pragma unroll
for(int warp_m_idx=0; warp_m_idx<(WARP_M_/32); warp_m_idx++) {
int dq_warp_buffer_store_global_offset = (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx*2) * seqlen_dq_stride;
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
dq_global_addr_offset = dq_block_buffer_store_global_offset + dq_warp_buffer_store_global_offset + k_tile_idx*32;
#pragma unroll 2
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
#pragma unroll 2
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
int dq_global_addr = dq_global_addr_offset + (min_tile_m + vec_index*8)*seqlen_dq_stride + min_tile_k + dq_lane_head_dim_idx*2;
if(Is_even_MN || ((m_block * kBlockM_) + (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx*2) + min_tile_m + vec_index*8) < binfo.actual_seqlen_q) {
dq_ptr[dq_global_addr] = DownCast<ElementAccum, Element>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_k + min_tile_m*2].f32[vec_index] * params.scale_softmax_rp_dropout);
}
}
}
}
}
}
}
}
}
#undef print_qk
#undef print_softmax_rescale_o
#undef print_dp
#undef print_ds
#ifdef DEBUGING
#define print_qk(block_id_m, bidb, bidh) {\
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
kq_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
s_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_dp(block_id_m, bidb, bidh) {\
int dp_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int dp_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = dp_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
dp_ptr[offset] = dp_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_ds(block_id_m, bidb, bidh) {\
int ds_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int ds_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = ds_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
ds_ptr[offset] = dS_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#endif
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, int STAGES, int USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dq_1colblock_gfx938(Params &params, int bidb, int bidh, int m_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
#if 1//defined(__gfx936__)
const bool Is_store_K = true;
const bool Is_preload_K = true;
const bool Is_preload_V = true;
#else
const bool Is_store_K = false;
const bool Is_preload_K = false;
const bool Is_preload_V = false;
#endif
const int K_prefetch_level = Is_preload_K ? 1 : 0;
const int V_prefetch_level = Is_preload_V ? 1 : 0;
const int Q_prefetch_level = 3;
Element* K_lds = (Element*)&(smem);
Element* Q_lds = (Element*)&(smem);
Element* dO_lds = (Element*)&(smem);
Element* V_lds = (Element*)&(smem) + kBlockN_* K;
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block < 0 || m_block * kBlockM_ >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM_ - params.window_size_left) / kBlockN_);
const int n_block_max = (!Is_causal && !Is_local) ? ceil_div(binfo.actual_seqlen_k, kBlockN_) : std::min(ceil_div(binfo.actual_seqlen_k, kBlockN_), flash::ceil_div((m_block + 1) * kBlockM_ + params.window_size_right, kBlockN_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dq_stride = params.dq_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + m_block * kBlockM_ * seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM_ * seqlen_o_stride;
const int row_offset_dq = binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh) + m_block * kBlockM_ * seqlen_dq_stride;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + m_block * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM_;
auto gQ = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q, seqlen_q_stride);
auto gK = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k, seqlen_k_stride);
auto gV = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v, seqlen_v_stride);
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO, seqlen_do_stride);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
dq_ptr = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? flash::ceil_div(kBlockM_, kBlockN_) : flash::ceil_div(kBlockM_, kBlockN_) + 1);
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> q_reg[(K/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_f16x2<Element> dO_reg[(K_v/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_fp32 acc_dq[(K/kBlockK_) * ((WARP_M_/32)*(kBlockK_/32))][4]={0};
float lse[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int lse_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
lse[mi*2 + min_tile_m] = (Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_) ? gLSE[lse_idx] : INFINITY;
}
}
float dP_sum_reg[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int dP_sum_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
dP_sum_reg[mi*2 + min_tile_m] = gdPsum[dP_sum_idx];
}
}
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, Q_lds, q_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gdO, dO_lds, dO_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (Is_preload_V){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
if constexpr (Is_preload_K){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for (int n_block = n_block_max - 1; n_block >= n_block_min ; --n_block) {
union_vec4_f16x2<Element> v_reg[((WARP_N_*kBlockK_)/(32*32))*2];
union_vec4_fp32 dp_reg[(WARP_M_/32)*(kBlockN_/32)][4]= {0};
//dP gemm
gemm_tt_kq_gfx938<false, Is_preload_K, Is_even_MN, 3, V_prefetch_level, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gdO, gV, dO_lds, V_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), dO_reg, v_reg, dp_reg, warp_id, seqlen_do_stride, seqlen_v_stride
);
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> k_reg[((WARP_M_*kBlockK_)/(32*32))*2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq_gfx938<Is_store_K, false, Is_even_MN, Q_prefetch_level, K_prefetch_level, K, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gQ, gK, Q_lds, K_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), q_reg, k_reg, s_reg, warp_id, seqlen_q_stride, seqlen_k_stride
);
*(uint64_t*)&gV -= ((kBlockN_ * seqlen_v_stride) * sizeof(Element));
if (Is_preload_V && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
apply_mask_bwd_gfx938<Is_even_MN, Is_local ? 3 : (Is_causal ? 1 : 0)>(s_reg, binfo.actual_seqlen_q - m_block * kBlockM_ - warp_id * 32, binfo.actual_seqlen_k - n_block * kBlockN_, (m_block * kBlockM_ + warp_id * 32) - (n_block * kBlockN_), params.window_size_left, params.window_size_right);
#ifdef DEBUGING
print_qk(m_block, bidb, bidh);
#endif
scale_apply_exp2_bwd_seq_q_major</*scale_max=*/false, WARP_M_, kBlockN_, union_vec4_fp32, ElementAccum>(s_reg, lse, params.scale_softmax_log2);
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh)
#endif
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
// return p * (dp - d);
};
union_vec4_fp32 dS_reg[(WARP_M_/32)*(kBlockN_/32)][4];
#pragma unroll
for (int ni = 0; ni < (kBlockN_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = pointwise_mult(
s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dp_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dP_sum_reg[min_tile_m + mi*2]);
}
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> dS_reg_fp16[(WARP_M_/32)*(kBlockN_/32)*2];
convert_pk_type_gfx938<WARP_M_, kBlockN_, Element>(dS_reg_fp16, dS_reg);
{
//dq gemm, K*dS
gpu_gemm_B_in_reg_gfx938<Is_store_K , false , Is_even_MN, K, kBlockK_, kBlockM_, kBlockN_, kBlockK_, WARP_M_, 2, Element>(gK, gK, K_lds, dS_reg_fp16, acc_dq, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_k_stride);
}
*(uint64_t*)&gK -= ((kBlockN_ * seqlen_k_stride) * sizeof(Element));
#if 1//defined(__gfx936__)
{
__syncthreads();
if (Is_preload_K && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
}
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
//mmac
{
int dq_lane_seq_idx = (lane_id >> 4);
int dq_lane_head_dim_idx = (lane_id & 15);
int dq_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
int dq_block_buffer_store_global_offset = k_loop * kBlockK_;
#pragma unroll
for(int warp_m_idx=0; warp_m_idx<(WARP_M_/32); warp_m_idx++) {
int dq_warp_buffer_store_global_offset = (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx) * seqlen_dq_stride;
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
dq_global_addr_offset = dq_block_buffer_store_global_offset + dq_warp_buffer_store_global_offset + k_tile_idx*32;
#pragma unroll 2
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
#pragma unroll 2
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
int dq_global_addr = dq_global_addr_offset + (min_tile_m*16 + vec_index*4)*seqlen_dq_stride + min_tile_k + dq_lane_head_dim_idx*2;
if(Is_even_MN || ((m_block * kBlockM_) + (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx) + min_tile_m*16 + vec_index*4) < binfo.actual_seqlen_q) {
dq_ptr[dq_global_addr] = DownCast<ElementAccum, Element>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_k + min_tile_m*2].f32[vec_index] * params.scale_softmax_rp_dropout);
}
}
}
}
}
}
}
}
}
#undef print_qk
#undef print_softmax_rescale_o
#undef print_dp
#undef print_ds
#define print_kq(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int qk_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int qk_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_id*WARP_N_ + qk_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + qk_warp_n_id*WARP_N_ + n_idx * 32 + ((lane_id & 15) << 1) + min_tile_n) < params.seqlen_k) && \
((block_id_m*kBlockM_ + qk_warp_m_id*WARP_M_ + m_idx*32 + reg_id * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q)) { \
int offset = qk_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ ((lane_id & 15) << 1) + min_tile_m*params.seqlen_k + ((lane_id / 16) * 2) *params.seqlen_k + min_tile_n ; \
kq_ptr[offset + reg_id * 8 *params.seqlen_k] = 0;(s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]); \
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
// #define print_kq(block_id_m, bidb, bidh) { \
// __builtin_amdgcn_sched_barrier(0);\
// __builtin_amdgcn_s_waitcnt(0);\
// __syncthreads();\
// __builtin_amdgcn_sched_barrier(0);\
// int qk_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
// int qk_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
// int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
// + block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_id*WARP_N_ + qk_warp_m_id*WARP_M_*params.seqlen_k; \
// for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
// for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
// for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
// for(int reg_id=0; reg_id<4; reg_id++) { \
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
// for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
// if(((n_block*kBlockN_ + qk_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
// ((block_id_m*kBlockM_ + qk_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
// int offset = qk_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
// kq_ptr[offset + reg_id *params.seqlen_k] = 0;(s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]); \
// } \
// } \
// } \
// } \
// } \
// } \
// } \
// __builtin_amdgcn_sched_barrier(0);\
// __builtin_amdgcn_s_waitcnt(0);\
// __syncthreads();\
// __builtin_amdgcn_sched_barrier(0);\
// }
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int s_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int s_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int s_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + s_warp_n_id*WARP_N_ + s_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + s_warp_n_id*WARP_N_ + n_idx * 32 + ((lane_id & 15) << 1) + min_tile_n) < params.seqlen_k) && \
((block_id_m*kBlockM_ + s_warp_m_id*WARP_M_ + m_idx*32 + reg_id * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q)) { \
int offset = s_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ ((lane_id & 15) << 1) + min_tile_m*params.seqlen_k + ((lane_id / 16) * 2)*params.seqlen_k + min_tile_n ;\
s_ptr[offset + reg_id * 8 * params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + ((lane_id & 15) << 1) + min_tile_n) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ ((lane_id & 15) << 1) + min_tile_m*params.seqlen_k + ((lane_id / 16) * 2)*params.seqlen_k + min_tile_n ;\
ds_ptr[offset + reg_id * 8 * params.seqlen_k] = (dS_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds_fp16(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + ((lane_id & 15) << 1) + min_tile_n) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ ((lane_id & 15) << 1) + min_tile_m*params.seqlen_k + ((lane_id / 16) * 2)*params.seqlen_k + min_tile_n ;\
ds_ptr[offset + reg_id * 8 * params.seqlen_k] = UpCast<Element,float>(dS_reg_fp16[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f16[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_dp(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int dp_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int dp_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int dp_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_id*WARP_N_ + dp_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {\
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + dp_warp_n_id*WARP_N_ + n_idx * 32 + ((lane_id & 15) << 1) + min_tile_n) < params.seqlen_k) && \
((block_id_m*kBlockM_ + dp_warp_m_id*WARP_M_ + m_idx*32 + reg_id * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q)) { \
int offset = dp_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ ((lane_id & 15) << 1) + min_tile_m*params.seqlen_k + ((lane_id / 16) * 2)*params.seqlen_k + min_tile_n ;\
dp_ptr[offset + reg_id * 8 * params.seqlen_k] = (dp_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
}\
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
template<class Element, class ElementAccumType, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, bool USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dk_dv_1colblock(Params &params, int bidb, int bidh, int n_block
) {
#ifdef DEBUGING
ElementAccumType * kq_ptr = static_cast<ElementAccumType*>(params.kq_ptr);
ElementAccumType * s_ptr = static_cast<ElementAccumType*>(params.s_ptr);
ElementAccumType * dp_ptr = static_cast<ElementAccumType*>(params.dp_ptr);
ElementAccumType * ds_ptr = static_cast<ElementAccumType*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* p_ptr = static_cast<Element*>(params.p_ptr);
// Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccumType* softmax_lse_ptr = static_cast<ElementAccumType*>(params.softmax_lse_ptr);
ElementAccumType* dsoftmax_sum = static_cast<ElementAccumType*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
// static_assert(kBlockM_=WARP_M_,"Error: kBlockM_ not equal WARP_M_!");
const int WARP_NUM = (kBlockM_*kBlockN_)/(WARP_M_*WARP_N_);
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
int K_lds_ratio;
const int K_prefetch_level = 3;
const int STAGES = 2;
const bool Is_store_Q = true;
const bool Is_store_dO = true;
const bool Is_preload_Q = true;
const bool Is_preload_dO = true;
const int dP_dO_prefetch_level = Is_store_dO ? 1 : 0;
const int Q_prefetech_level = Is_preload_Q ? 1 : 0;
if constexpr (K_prefetch_level == 2){
K_lds_ratio = (K / kBlockK_) / 2;
} else {
K_lds_ratio = (K_prefetch_level == 3) ? 0 : STAGES;
}
// Element* K_lds = (Element*)&(smem);
// Element* Q_lds = K_lds + (kBlockN_/32) * (kBlockK_/32)*(32*34) * K_lds_ratio;
// Element* V_lds = K_prefetch_level == 2 ? Q_lds : K_lds;
// Element* dO_lds = Q_lds;
Element* K_lds = (Element*)&(smem);
Element* dO_lds = K_lds + (kBlockN_/32) * (kBlockK_/32)*(32*34) * K_lds_ratio;
Element* V_lds = K_prefetch_level == 2 ? dO_lds : K_lds;
Element* Q_lds = Is_store_Q ? dO_lds + (kBlockM_/32) * (K_v/32)*(32*34) : dO_lds;
#if 0//defined(__gfx936__)
auto pointwise_mult = [](vec2_fp32 p, vec2_fp32 dp, vec2_fp32 d) {
auto d0 = (!Is_dropout || p[0] >= 0 ? dp[0] - d[0] : d[0]);
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return __builtin_hcu_v_pk_mul_f32(p, vec2_fp32{d0, d1});
};
#else
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#endif
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (n_block * kBlockN_ >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
const int m_block_min = (!Is_causal && !Is_local) ? 0 : std::max(0, (n_block * kBlockN_ - params.window_size_right) / kBlockM_);
const int m_block_max = !Is_local ? ceil_div(binfo.actual_seqlen_q, kBlockM_) : std::min(ceil_div(binfo.actual_seqlen_q, kBlockM_), ceil_div((n_block + 1) * kBlockN_ + params.window_size_left, kBlockM_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dk_stride = params.dk_row_stride;
int seqlen_dv_stride = params.dv_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + (m_block_max - 1) * kBlockM_* seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_o_stride;
// const int row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + (m_block_max - 1) * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + (m_block_max - 1) * kBlockM_;
// Element * gQ = reinterpret_cast<Element *>(q_ptr) + row_offset_q;
auto gQ = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q);
// Element * gK = reinterpret_cast<Element *>(k_ptr) + row_offset_k;
auto gK = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k);
// Element * gV = reinterpret_cast<Element *>(v_ptr) + row_offset_v;
auto gV = tcp_cache_swizzle_func<K_v, Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v);
// Element * gdO = reinterpret_cast<Element *>(do_ptr) + row_offset_dO;
auto gdO = tcp_cache_swizzle_func<K_v, Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
// Element * gdQ = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
ElementAccumType *gLSE = reinterpret_cast<ElementAccumType *>(softmax_lse_ptr) + row_offset_lse;
ElementAccumType *gdPsum = reinterpret_cast<ElementAccumType *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int m_masking_steps = (!Is_causal && !Is_local)
? 0
: flash::ceil_div(kBlockN_, kBlockM_);
/***************************************************************************************************************************/
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec2_f16x2<Element> k_reg[(K/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2/((K_prefetch_level == 3)? 1 : 2)][2]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
union_vec2_f16x2<Element> v_reg[(K_v/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2][2];
__builtin_amdgcn_sched_barrier(0);
prefetch_to_vgpr<K_v, kBlockN_, kBlockK_, WARP_N_, Element, ElementAccumType, Is_even_MN>(gV, V_lds, v_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), seqlen_v_stride);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr<K, kBlockN_, kBlockK_, WARP_N_, Element, ElementAccumType, Is_even_MN>(gK, K_lds, k_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), seqlen_k_stride);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
if constexpr (Is_preload_Q){
prefetch_to_tmp_lds_wait<Is_even_MN, K, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, Element>(gQ, Q_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id, seqlen_q_stride);
}
if constexpr (Is_preload_dO){
prefetch_to_tmp_lds_wait<Is_even_MN, K_v, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, Element>(gdO, dO_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id, seqlen_do_stride);
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
union_vec4_fp32 acc_dv[(K_v/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
union_vec4_fp32 acc_dk[(K/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
for (int m_block = m_block_max - 1; m_block >= m_block_min; --m_block) {
union_vec2_f16x2<Element> q_reg[((WARP_M_*kBlockK_)/(32*32))*2][2];
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq<Is_store_Q, Is_preload_dO, Is_even_MN, K_prefetch_level, Q_prefetech_level, K, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(gK, gQ, K_lds, Q_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), k_reg, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_q_stride);
float lse[kBlockM_/4];
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const int lse_idx = mi*32 + vec_idx*8 + (lane_id >> 4)*2 + min_tile_m;
lse[(mi*2 + min_tile_m)*4 + vec_idx] = Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_ ? gLSE[lse_idx] : INFINITY;
}
}
}
apply_mask_bwd<Is_even_MN, Is_local ? 3 : (Is_causal ? 2 : 0)>(s_reg, binfo.actual_seqlen_k - n_block * kBlockN_ - warp_id * 32, binfo.actual_seqlen_q - m_block * kBlockM_, (n_block * kBlockN_ + warp_id * 32) - m_block * kBlockM_, params.window_size_right, params.window_size_left);
#ifdef DEBUGING
print_kq(m_block, bidb, bidh);
#endif
float dP_sum_reg[kBlockM_/4];
#pragma unroll
for (int vec_idx = 0; vec_idx < (kBlockM_/8); ++vec_idx) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
dP_sum_reg[vec_idx*2 + min_tile_m] = gdPsum[vec_idx*8 + ((lane_id >> 4)*2) + min_tile_m];
}
}
{
scale_apply_exp2_bwd</*scale_max=*/false, kBlockM_, WARP_N_>(s_reg, lse, params.scale_softmax_log2);
}
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh);
#endif
// //TODO:drop
union_vec2_f16x2<Element> p_reg[(kBlockM_/32)*(WARP_N_/32)][4];
convert_pk_type<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
//QK(seq_q, seq_kv), seq_q is continuous, seq_kv is not continuous
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dv gemm,dO*P
gpu_gemm_B_in_reg<Is_preload_dO, Is_store_dO, false, Is_even_MN, K_v, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gdO, gQ, dO_lds, p_reg, acc_dv, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_do_stride);
}
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
union_vec2_f16x2<Element> dO_reg[((WARP_M_*kBlockK_)/(32*32))*2][2];
union_vec4_fp32 dp_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
{
// dP gemm
gemm_tt_kq<Is_store_dO, false, Is_even_MN, 3, dP_dO_prefetch_level, K_v, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(
gV, gdO, V_lds, dO_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), v_reg, dO_reg, dp_reg, warp_id, seqlen_v_stride, seqlen_do_stride);
}
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_fp32 dS_reg[(WARP_N_/32)*(kBlockM_/32)][4];
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int ni = 0; ni < (WARP_N_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#if 0//defined(__gfx936__)
#pragma unroll
for(int vec_idx=0; vec_idx<2; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
vec2_fp32{gdPsum[vec_idx*16 + mi*8*4 + ((lane_id >> 4)*2) + min_tile_m], gdPsum[vec_idx*16 + mi*8*4 + ((lane_id >> 4)*2) + min_tile_m + 8]});
}
#else
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
// if((m_block*kBlockM_ + vec_idx * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q){
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx] = pointwise_mult(s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx], dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx], dP_sum_reg[vec_idx*2 + min_tile_m]);
// }
// else{
// dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx] = 0;
// }
}
#endif
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec2_f16x2<Element> dS_reg_fp16[(WARP_N_/32)*(kBlockM_/32)][4];
convert_pk_type<kBlockM_, WARP_N_, Element>(dS_reg_fp16, dS_reg);
// #ifdef DEBUGING
// print_ds_fp16(m_block, bidb, bidh);
// #endif
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
{
//dk gemm, Q*dS
gpu_gemm_B_in_reg<Is_store_Q , false , false, Is_even_MN, K, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gQ, gdO, Q_lds, dS_reg_fp16, acc_dk, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_q_stride);
}
gLSE = gLSE + (-int(kBlockM_));
gdPsum = gdPsum - kBlockM_;
*(uint64_t*)&gQ -= ((kBlockM_ * seqlen_q_stride) * sizeof(Element));
*(uint64_t*)&gdO -= ((kBlockM_ * seqlen_do_stride) * sizeof(Element));
{
__syncthreads();
if (Is_preload_Q && m_block > m_block_min){
prefetch_to_tmp_lds_wait<Is_even_MN, K, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, Element>(gQ, Q_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id, seqlen_q_stride);
}
// __syncthreads();
if (Is_preload_dO && m_block > m_block_min){
prefetch_to_tmp_lds_wait<Is_even_MN, K_v, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, Element>(gdO, dO_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id, seqlen_do_stride);
}
}
}
{
// dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
int dv_lane_seq_idx = (lane_id >> 4);
int dv_lane_head_dim_idx = (lane_id & 15);
int dv_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = (dv_lane_head_dim_idx*2) + (dv_lane_seq_idx*2 * seqlen_dv_stride);
int s_offset = (min_tile_n*seqlen_dv_stride + vec_index * 8 * seqlen_dv_stride) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index]);
v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index]);
if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + warp_n_idx*32 + dv_lane_seq_idx*2 + min_tile_n + vec_index * 8 < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
{
// dk_ptr = dk_ptr + binfo.k_offset1(params.dk_batch_stride, params.dk_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dk_head_stride,bidh);
dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
auto gdK = tcp_cache_swizzle_func<K, Element>(dk_ptr);
int dk_lane_seq_idx = (lane_id >> 4);
int dk_lane_head_dim_idx = (lane_id & 15);
int dk_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
vec2_Element<Element> v_data;
int v_offset = dk_lane_head_dim_idx*2 + (dk_lane_seq_idx*2) * seqlen_dk_stride;
int s_offset = n_block * kBlockN_ * seqlen_dk_stride + (warp_id*WARP_N_) * seqlen_dk_stride + (min_tile_n*seqlen_dk_stride + vec_index * 8 * seqlen_dk_stride + k_tile_idx*32 + k_loop * kBlockK_ + warp_n_idx*32);
int known_offset = 0;
v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index] * params.scale_softmax_rp_dropout);
v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index] * params.scale_softmax_rp_dropout);
if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + dk_lane_seq_idx*2 + min_tile_n + vec_index * 8 < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
}
#undef print_kq
#undef print_dq
#undef print_softmax_rescale_o
#undef print_ds
#undef print_ds_fp16
#undef print_dp
#define print_kq(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int qk_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int qk_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_id*WARP_N_ + qk_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + qk_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + qk_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = qk_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
kq_ptr[offset + reg_id *params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]); \
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int s_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int s_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int s_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + s_warp_n_id*WARP_N_ + s_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + s_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + s_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = s_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
s_ptr[offset + reg_id * params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
ds_ptr[offset + reg_id * params.seqlen_k] = (dS_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds_fp16(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
ds_ptr[offset + reg_id * 8 * params.seqlen_k] = UpCast<Element,float>(dS_reg_fp16[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f16[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_dp(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int dp_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int dp_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int dp_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_id*WARP_N_ + dp_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {\
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + dp_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + dp_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = dp_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
dp_ptr[offset + reg_id * params.seqlen_k] = (dp_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]);\
} \
} \
} \
} \
} \
}\
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
/*
load q/k:累加方向为主序方向
ps: 在offset传0的情况下,T和R的取值似乎没有影响!?
调用matrix_load_32x32_b16:
R=0: offset in column direction
load Q: T=1: row major
load K: T=0: column major
m_ab=1: 线程数据在主序方向拼接
调用ds_read_matrix_trans_format(和m_ab保持一致):
element:0x2 row:0x2 col:0x1 alt:0x0
load v:累加方向为非主序方向
调用matrix_load_32x32_b16:
R=0: offset in column direction
T=1: row major
m_ab=0: 线程数据在非主序方向拼接
调用ds_read_matrix_format(和m_ab保持一致)
*/
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, bool USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, int bidb, int bidh, int n_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* p_ptr = static_cast<Element*>(params.p_ptr);
// Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
// static_assert(kBlockM_=WARP_M_,"Error: kBlockM_ not equal WARP_M_!");
const int WARP_NUM = (kBlockM_*kBlockN_)/(WARP_M_*WARP_N_);
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
int K_lds_ratio;
// 0表示k不预取;1表示k预取一半到寄存器;2表示一半到寄存器,一半到LDS;3表示全部预取到寄存器
const int K_prefetch_level = 3;
const int STAGES = 2;
const bool Is_store_Q = true;
const bool Is_store_dO = true;
const bool Is_preload_Q = true;
const bool Is_preload_dO = true;
const int dP_dO_prefetch_level = Is_store_dO ? 1 : 0;
const int Q_prefetech_level = Is_preload_Q ? 1 : 0;
if constexpr (K_prefetch_level == 2){
K_lds_ratio = (K / kBlockK_) / 2;
} else {
K_lds_ratio = (K_prefetch_level == 3) ? 0 : STAGES;
}
Element* K_lds = (Element*)&(smem);
Element* dO_lds = K_lds + kBlockN_ * kBlockK_ * K_lds_ratio;
Element* V_lds = K_prefetch_level == 2 ? dO_lds : K_lds;
Element* Q_lds = Is_store_Q ? dO_lds + kBlockM_ * K_v : dO_lds;
#if 0//defined(__gfx938__)
auto pointwise_mult = [](vec2_fp32 p, vec2_fp32 dp, vec2_fp32 d) {
auto d0 = (!Is_dropout || p[0] >= 0 ? dp[0] - d[0] : d[0]);
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
};
#else
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#endif
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (n_block * kBlockN_ >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
const int m_block_min = (!Is_causal && !Is_local) ? 0 : std::max(0, (n_block * kBlockN_ - params.window_size_right) / kBlockM_);
const int m_block_max = !Is_local ? ceil_div(binfo.actual_seqlen_q, kBlockM_) : std::min(ceil_div(binfo.actual_seqlen_q, kBlockM_), ceil_div((n_block + 1) * kBlockN_ + params.window_size_left, kBlockM_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dk_stride = params.dk_row_stride;
int seqlen_dv_stride = params.dv_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + (m_block_max - 1) * kBlockM_* seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_o_stride;
// const int row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + (m_block_max - 1) * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + (m_block_max - 1) * kBlockM_;
auto gQ = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q, seqlen_q_stride);
auto gK = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k, seqlen_k_stride);
auto gV = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v, seqlen_v_stride);
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO, seqlen_do_stride);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int m_masking_steps = (!Is_causal && !Is_local)
? 0
: flash::ceil_div(kBlockN_, kBlockM_);
/***************************************************************************************************************************/
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> k_reg[(K/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2/((K_prefetch_level == 3)? 1 : 2)]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
union_vec4_f16x2<Element> v_reg[(K_v/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2];
//提前读取V到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gV, V_lds, v_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
//提前读取K到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, K_lds, k_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
//提前读取Q到lds
if constexpr (Is_preload_Q){
prefetch_to_lds_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, 0, Q_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id);
}
//提前读取dO到lds
if constexpr (Is_preload_dO){
prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id);
}
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
union_vec4_fp32 acc_dv[(K_v/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
union_vec4_fp32 acc_dk[(K/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
for (int m_block = m_block_max - 1; m_block >= m_block_min; --m_block) {
union_vec4_f16x2<Element> q_reg[((WARP_M_*kBlockK_)/(32*32))*2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
/*
qk gemm
结果矩阵layout:
0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
...
0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
*/
gemm_tt_kq_gfx938<Is_store_Q, Is_preload_dO, Is_even_MN, K_prefetch_level, Q_prefetech_level, K, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(
gK, gQ, K_lds, Q_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), k_reg, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_q_stride);
/*
lse layout:
4 warp:
32
32
32
32
因为warp在seqlen_k维度,所以不区分warp
每16个thread持有相同的lse,所以需要/4
*/
float lse[kBlockM_/4];
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const int lse_idx = mi*32 + min_tile_m * 16 + (lane_id >> 4)*4 + vec_idx;
lse[(mi*2 + min_tile_m)*4 + vec_idx] = Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_ ? gLSE[lse_idx] : INFINITY;
}
}
}
apply_mask_bwd_gfx938<Is_even_MN, Is_local ? 3 : (Is_causal ? 2 : 0)>(s_reg, binfo.actual_seqlen_k - n_block * kBlockN_ - warp_id * 32, binfo.actual_seqlen_q - m_block * kBlockM_, (n_block * kBlockN_ + warp_id * 32) - m_block * kBlockM_, params.window_size_right, params.window_size_left);
#ifdef DEBUGING
print_kq(m_block, bidb, bidh);
#endif
//do . o后在headdim维度reduce求和,读取方式和lse一样,因为pad了,所以无需边界判断
float dP_sum_reg[kBlockM_/4];
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const int dPsum_idx = mi*32 + min_tile_m * 16 + (lane_id >> 4)*4 + vec_idx;
dP_sum_reg[(mi*2 + min_tile_m)*4 + vec_idx] = gdPsum[dPsum_idx];
}
}
}
{
scale_apply_exp2_bwd</*scale_max=*/false, kBlockM_, WARP_N_>(s_reg, lse, params.scale_softmax_log2);
}
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh);
#endif
// //TODO:drop
union_vec4_f16x2<Element> p_reg[(kBlockM_/32)*(WARP_N_/32)*2];
// convert_pk_type<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
convert_pk_type_gfx938<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
//QK(seq_q, seq_kv), seq_q is continuous, seq_kv is not continuous
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dv gemm,dO*P
gpu_gemm_B_in_reg_gfx938<Is_preload_dO, Is_store_dO, Is_even_MN, K_v, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gdO, gQ, dO_lds, p_reg, acc_dv, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_do_stride);
}
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
union_vec4_f16x2<Element> dO_reg[((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_fp32 dp_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
{
// dP gemm dO * V
gemm_tt_kq_gfx938<Is_store_dO, false, Is_even_MN, 3, dP_dO_prefetch_level, K_v, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(
gV, gdO, V_lds, dO_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), v_reg, dO_reg, dp_reg, warp_id, seqlen_v_stride, seqlen_do_stride);
}
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_fp32 dS_reg[(WARP_N_/32)*(kBlockM_/32)][4];
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int ni = 0; ni < (WARP_N_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#if 0//defined(__gfx938__)
#pragma unroll
for(int vec_idx=0; vec_idx<2; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
vec2_fp32{dP_sum_reg[min_tile_m*4 + vec_idx * 2], dP_sum_reg[min_tile_m*4 + vec_idx * 2 + 1]});
}
#else
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx],
dP_sum_reg[min_tile_m*4 + vec_idx]);
}
#endif
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> dS_reg_fp16[(WARP_N_/32)*(kBlockM_/32)*2];
convert_pk_type_gfx938<kBlockM_, WARP_N_, Element>(dS_reg_fp16, dS_reg);
// #ifdef DEBUGING
// print_ds_fp16(m_block, bidb, bidh);
// #endif
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dk gemm, Q*dS
gpu_gemm_B_in_reg_gfx938<Is_store_Q , false, Is_even_MN, K, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gQ, gdO, Q_lds, dS_reg_fp16, acc_dk, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_q_stride);
}
gLSE = gLSE + (-int(kBlockM_));
gdPsum = gdPsum - kBlockM_;
*(uint64_t*)&gQ -= ((kBlockM_ * seqlen_q_stride) * sizeof(Element));
*(uint64_t*)&gdO -= ((kBlockM_ * seqlen_do_stride) * sizeof(Element));
{
__syncthreads();
if (Is_preload_Q && m_block > m_block_min){
prefetch_to_lds_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, 0, Q_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id);
}
// __syncthreads();
if (Is_preload_dO && m_block > m_block_min){
prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id);
}
}
}
{
// dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
int dv_lane_seq_idx = (lane_id >> 4);
int dv_lane_head_dim_idx = (lane_id & 15);
int dv_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = dv_lane_head_dim_idx*2 + dv_lane_seq_idx * seqlen_dv_stride;
int s_offset = (min_tile_n*seqlen_dv_stride*16 + vec_index * 4 * seqlen_dv_stride) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index]);
v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index]);
if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + warp_n_idx*32 + dv_lane_seq_idx + min_tile_n*16 + vec_index * 4 < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
{
// dk_ptr = dk_ptr + binfo.k_offset1(params.dk_batch_stride, params.dk_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dk_head_stride,bidh);
dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
auto gdK = tcp_cache_swizzle_func<K, Element>(dk_ptr);
int dk_lane_seq_idx = (lane_id >> 4);
int dk_lane_head_dim_idx = (lane_id & 15);
int dk_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
vec2_Element<Element> v_data;
int v_offset = dk_lane_head_dim_idx*2 + dk_lane_seq_idx * seqlen_dk_stride;
int s_offset = n_block * kBlockN_ * seqlen_dk_stride + (warp_id*WARP_N_) * seqlen_dk_stride + (min_tile_n*seqlen_dk_stride*16 + vec_index * 4 * seqlen_dk_stride + k_tile_idx*32 + k_loop * kBlockK_ + warp_n_idx*32);
int known_offset = 0;
v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index] * params.scale_softmax_rp_dropout);
v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index] * params.scale_softmax_rp_dropout);
if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + dk_lane_seq_idx + min_tile_n*16 + vec_index * 4 < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
}
#undef print_dq
#undef print_softmax_rescale_o
#undef print_ds
#undef print_ds_fp16
#undef print_dp
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "intrinsic.h"
#include "prefetch.h"
// K BLOCK_K BLOCK_N BLOCK_M BLOCK_K WARP_N
template<bool Is_preload_A, bool Is_store_A, bool Is_preload_C, bool Is_even_MN, int M/*head_dim*/, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum = float>
__forceinline__ __device__ void gpu_gemm_B_in_reg(vec4_uint A_ptr, vec4_uint C_ptr, Element* A_lds, union_vec2_f16x2<Element> B_reg[(WARP_M/32)*(BLOCK_K/32)][4], union_vec4_fp32 C_reg[(M/BLOCK_M)*(WARP_M/32)*(WARP_N/32)][4], int N/*seq_kv*/, int K/*seq_q*/, int warp_id, int seqlen_A_stride)
{
#if 1
const int WARP_NUM = (BLOCK_M*BLOCK_N)/(WARP_M*WARP_N);
const int A_lds_load_num = (BLOCK_M*BLOCK_K) / (4*32);
static_assert(BLOCK_K>=32, "Error: gpu_gemm_B_in_reg gemm BLOCK_K must be equal or greater than 32");
static_assert(BLOCK_N>=WARP_N, "Error: gpu_gemm_B_in_reg gemm BLOCK_N must be equal or greater than WARP_N");
static_assert(BLOCK_M==WARP_M, "Error: gpu_gemm_B_in_reg gemm BLOCK_M must be equal to WARP_M");
union_vec2_f16x2<Element> A_reg[((WARP_M*BLOCK_K)/(32*32))*2][2];
//c mini tile is 32*32
// vec4_fp32 o[(WARP_M/32)*(WARP_N/32)][4]={0};
// __shared__ Element A_lds[STAGES*BLOCK_N * BLOCK_K];
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int row = lane_id % 4;
int col = lane_id / 4;
int stage_id = 0;
if(STAGES > 1 && (!Is_preload_A)) {
int m_loop = 0;
int A_block_buffer_load_global_offset = m_loop*BLOCK_M ; //+ k_loop * BLOCK_K * N;
// A_ptr buffer load mini size is 32*32, buffer_load_dword mini size is 4*32
int A_lane_m_idx = lane_id % 16;
// int A_lane_k_idx = lane_id / 16;
int A_lane_k_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
for(int warp_loop=warp_id; warp_loop<A_lds_load_num; warp_loop+=WARP_NUM) {
// for(int warp_loop_tmp = 0; warp_loop_tmp < A_lds_load_num / WARP_NUM; warp_loop_tmp++){
// int warp_loop = warp_loop_tmp * WARP_NUM + warp_id;
//global->lds, right matrix
int A_warp_buffer_load_k_id = (warp_loop / (BLOCK_M/32)); //seq_len
int A_warp_buffer_load_m_id = (warp_loop % (BLOCK_M/32)); //head_dim
{
int A_warp_buffer_load_global_offset = (A_warp_buffer_load_m_id * 32);
int A_warp_buffer_load_lds_offset = (A_warp_buffer_load_m_id * 32) + (A_warp_buffer_load_k_id * 4 * BLOCK_M);
if(Is_store_A){
A_warp_buffer_load_lds_offset = (A_warp_buffer_load_m_id * 32) + (A_warp_buffer_load_k_id * (4 * BLOCK_M + 2));
}
int A_gsoffset = (A_block_buffer_load_global_offset + A_warp_buffer_load_global_offset)/2 ;
int A_gvoffset;
if constexpr (Is_even_MN){
A_gvoffset = ((A_lane_m_idx * 2 + (A_lane_k_idx + A_warp_buffer_load_k_id*4)* seqlen_A_stride))/2 ;
} else {
A_gvoffset = ((A_lane_m_idx * 2 + min(A_lane_k_idx + A_warp_buffer_load_k_id*4, K-1)* seqlen_A_stride))/2 ;
}
// int gvOffset = (64*8 + lane_id*8)/2;
int A_lds_offset = ((stage_id)*BLOCK_K*BLOCK_M + A_warp_buffer_load_lds_offset)/2;
if(Is_store_A){
A_lds_offset = ((stage_id)*(BLOCK_K/32)*(BLOCK_M/32)*32*34 + A_warp_buffer_load_lds_offset)/2;
}
builtin_buffer_load_dword_lds(A_lds , A_ptr, A_lds_offset, A_gsoffset, A_gvoffset);
}
}
}
#if 1
// int lds_offset = row * 8 + col * 32;
for(int m_loop = 1; m_loop<(M/BLOCK_M) + 1; m_loop++) {
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id = stage_id ^ 1;
}
}
if(STAGES == 1) {
m_loop--;
}
if((!Is_preload_A)&& m_loop < (M/BLOCK_M)) {
int A_block_buffer_load_global_offset = m_loop*BLOCK_M;
if(Is_store_A){
int A_lds_stage_offset = (stage_id)*(BLOCK_K/32)*(BLOCK_M/32)*32*34;
buffer_load_lds_tile_pad(Is_even_MN, WARP_NUM, seqlen_A_stride, BLOCK_M, BLOCK_K, Element, A_ptr, A_lds, A_block_buffer_load_global_offset, A_lds_stage_offset, K, warp_id, lane_id);
} else {
int A_lds_stage_offset = (stage_id)*BLOCK_K*BLOCK_M;
buffer_load_lds_tile(Is_even_MN, WARP_NUM, seqlen_A_stride, BLOCK_M, BLOCK_K, Element, A_ptr, A_lds, A_block_buffer_load_global_offset, A_lds_stage_offset, K, warp_id, lane_id);
}
}
if(!Is_preload_A){
if(STAGES > 1) {
if(m_loop < (M/BLOCK_M)){
// if constexpr(Is_preload_A){
// vmcnt_wait((M/BLOCK_M - m_loop) * (BLOCK_K*BLOCK_M) / (4*32)/WARP_NUM);
// } else {
vmcnt_wait((BLOCK_K*BLOCK_M) / (4*32)/WARP_NUM);
// }
} else {
vmcnt_wait(0);
}
} else {
vmcnt_wait(0);
}
}
if constexpr (STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id --;
} else {
stage_id = stage_id ^ 1;
}
}
if (Is_preload_C && m_loop > 1) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int C_block_buffer_load_global_offset = (m_loop - 2)*BLOCK_M;
// A_ptr buffer load mini size is 32*32, buffer_load_dword mini size is 4*32
int C_lane_m_idx = lane_id % 16;
// int A_lane_k_idx = lane_id / 16;
int C_lane_k_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
for(int warp_loop_temp=0; warp_loop_temp< A_lds_load_num/WARP_NUM; warp_loop_temp++) {
int warp_loop = warp_loop_temp * WARP_NUM + warp_id;
//global->lds, right matrix
int C_warp_buffer_load_k_id = (warp_loop / (BLOCK_K/32)); //seq_len
int C_warp_buffer_load_m_id = (warp_loop % (BLOCK_M/32)); //head_dim
{
int C_warp_buffer_load_global_offset = (C_warp_buffer_load_m_id * 32);
int C_warp_buffer_load_lds_offset = (C_warp_buffer_load_m_id * 32) + (C_warp_buffer_load_k_id * 4 * BLOCK_M);
if(Is_store_A){
C_warp_buffer_load_lds_offset = (C_warp_buffer_load_m_id * 32) + (C_warp_buffer_load_k_id * (4 * BLOCK_M + 2));
}
int C_gsoffset = (C_block_buffer_load_global_offset + C_warp_buffer_load_global_offset)/2 ;
int C_gvoffset;
if constexpr (Is_even_MN){
C_gvoffset = ((C_lane_m_idx * 2 + (C_lane_k_idx + C_warp_buffer_load_k_id*4)* M))/2 ;
} else {
C_gvoffset = ((C_lane_m_idx * 2 + min(C_lane_k_idx + C_warp_buffer_load_k_id*4, K-1)* M))/2 ;
}
// int gvOffset = (64*8 + lane_id*8)/2;
int A_lds_offset = ((m_loop - 2)*BLOCK_K*BLOCK_M + C_warp_buffer_load_lds_offset)/2;
if(Is_store_A){
A_lds_offset = ((m_loop - 2)*(BLOCK_K/32)*(BLOCK_M/32)*32*34 + C_warp_buffer_load_lds_offset)/2;
}
builtin_buffer_load_dword_lds(A_lds , C_ptr, A_lds_offset, C_gsoffset, C_gvoffset);
}
}
}
//lds -> vgpr use ds_read_m; left matrix
int A_lane_head_dim_idx = lane_id % 16;
int A_lane_seq_idx = lane_id / 16;
// __builtin_amdgcn_s_waitcnt(4080 + ((M/BLOCK_M) - m_loop)*(A_lds_load_num/WARP_NUM));
// vmcnt_wait_no_barrier(((M/BLOCK_M) - m_loop)*(A_lds_load_num/WARP_NUM));
vec2_Element<Element> *A_lds_v2fp16 = (vec2_Element<Element> *)(A_lds );
//lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(WARP_M/32); head_dim_idx++) {
#pragma unroll
for(int seq_idx=0; seq_idx<(BLOCK_K/32); seq_idx++) {
#pragma unroll
for(int seq_min_tile_idx=0; seq_min_tile_idx<2; seq_min_tile_idx++) { // min k tile
// __builtin_amdgcn_s_waitcnt(4082);
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) //16*32 half need 4 ds_read_b32
{
// int lds_offset = (stage_id*BLOCK_K*BLOCK_M + (seq_idx*32*BLOCK_M) + head_dim_idx*32 * 32 + A_lane_seq_idx/2*4*32 + A_lane_seq_idx%2*32 + (seq_min_tile_idx*32*2) + vec_idx*8*32 + A_lane_head_dim_idx*2)/2;
int lds_offset = stage_id * BLOCK_K * BLOCK_M / 2 + seq_idx * BLOCK_M * 16 + head_dim_idx * 512 + A_lane_seq_idx/2 * 64 + A_lane_seq_idx % 2 * 16 + seq_min_tile_idx * 32 + vec_idx * 128 + A_lane_head_dim_idx;
if constexpr(Is_preload_A || Is_store_A){
// lds_offset = (stage_id*(BLOCK_K/32)*(BLOCK_M/32)*32*34 + (seq_idx*34*BLOCK_M) + head_dim_idx*32 * 34 + A_lane_seq_idx/2*(4*32 + 2) + A_lane_seq_idx%2*32 + (seq_min_tile_idx*32*2) + vec_idx*(8*32+4) + A_lane_head_dim_idx*2)/2;
// lds_offset += (stage_id*(BLOCK_K/32)*(BLOCK_M/32)*32*2 + 2*seq_idx*BLOCK_M + head_dim_idx * 32 * 2 + A_lane_seq_idx/2*2 + vec_idx*4)/2;
lds_offset += stage_id * BLOCK_K * BLOCK_M / 32 + seq_idx * BLOCK_M + head_dim_idx * 32 + A_lane_seq_idx / 2 + vec_idx * 2;
}
inline_ds_read_b32_wait(A_lds_v2fp16, lds_offset, A_reg[(head_dim_idx*(BLOCK_K/32) + seq_idx)*2 + seq_min_tile_idx][vec_idx/2].f16x2[vec_idx%2]);
}
// #pragma unroll
// for(int vec_idx=0; vec_idx<2; vec_idx++) //16*32 half need 4 ds_read_b32
// {
// int lds_offset = (stage_id*BLOCK_K*BLOCK_M + (seq_idx*32*BLOCK_M) + head_dim_idx*32 * 32 + A_lane_seq_idx/2*4*32 + A_lane_seq_idx%2*32 + (seq_min_tile_idx*32*2) + vec_idx*16*32 + A_lane_head_dim_idx*2)/2;
// if constexpr(Is_preload_A || Is_store_A){
// lds_offset = (stage_id*(BLOCK_K/32)*(BLOCK_M/32)*32*34 + (seq_idx*34*BLOCK_M) + head_dim_idx*32 * 34 + A_lane_seq_idx/2*(4*32 + 2) + A_lane_seq_idx%2*32 + (seq_min_tile_idx*32*2) + vec_idx*(16*32+8) + A_lane_head_dim_idx*2)/2;
// }
// inline_ds_read2_b32_no_wait(A_lds_v2fp16, lds_offset, A_reg[(head_dim_idx*(BLOCK_K/32) + seq_idx)*2 + seq_min_tile_idx][vec_idx].f32, 4*32);
// }
}
}
}
// asm volatile("s_waitcnt lgkmcnt(0)");
// __builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 1){
m_loop++;
}
asm volatile("s_setprio 1");
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(BLOCK_K/32); k_idx++) { //BLOCK_K mini size is 32
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
C_reg[(m_loop-1) * ((WARP_M/32)*(WARP_N/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 = flash::mmac<half_t, ElementAccum>(
vec4_Element<half_t>{
UpCast<Element, half_t, true>(A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 0][min_tile_k].f16x2[0][min_tile_m]),
UpCast<Element, half_t, true>(A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 1][min_tile_k].f16x2[0][min_tile_m]),
UpCast<Element, half_t, true>(A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 0][min_tile_k].f16x2[1][min_tile_m]),
UpCast<Element, half_t, true>(A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 1][min_tile_k].f16x2[1][min_tile_m])},
vec4_Element<half_t>{
UpCast<Element, half_t, true>(B_reg[(k_idx)*(WARP_N/32) + n_idx][0*2 + min_tile_n].f16x2[min_tile_k][0]),
UpCast<Element, half_t, true>(B_reg[(k_idx)*(WARP_N/32) + n_idx][1*2 + min_tile_n].f16x2[min_tile_k][0]),
UpCast<Element, half_t, true>(B_reg[(k_idx)*(WARP_N/32) + n_idx][0*2 + min_tile_n].f16x2[min_tile_k][1]),
UpCast<Element, half_t, true>(B_reg[(k_idx)*(WARP_N/32) + n_idx][1*2 + min_tile_n].f16x2[min_tile_k][1])},
C_reg[(m_loop-1) * ((WARP_M/32)*(WARP_N/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
} else {
C_reg[(m_loop-1) * ((WARP_M/32)*(WARP_N/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
vec4_Element<Element>{
A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 0][min_tile_k].f16x2[0][min_tile_m],
A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 1][min_tile_k].f16x2[0][min_tile_m],
A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 0][min_tile_k].f16x2[1][min_tile_m],
A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 1][min_tile_k].f16x2[1][min_tile_m]},
vec4_Element<Element>{
B_reg[(k_idx)*(WARP_N/32) + n_idx][0*2 + min_tile_n].f16x2[min_tile_k][0],
B_reg[(k_idx)*(WARP_N/32) + n_idx][1*2 + min_tile_n].f16x2[min_tile_k][0],
B_reg[(k_idx)*(WARP_N/32) + n_idx][0*2 + min_tile_n].f16x2[min_tile_k][1],
B_reg[(k_idx)*(WARP_N/32) + n_idx][1*2 + min_tile_n].f16x2[min_tile_k][1]},
C_reg[(m_loop-1) * ((WARP_M/32)*(WARP_N/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
}
}
asm volatile("s_setprio 0");
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id ^=1;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
} else {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
if constexpr (Is_preload_C) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int C_block_buffer_load_global_offset = 3*BLOCK_M;
// A_ptr buffer load mini size is 32*32, buffer_load_dword mini size is 4*32
int C_lane_m_idx = lane_id % 16;
// int A_lane_k_idx = lane_id / 16;
int C_lane_k_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
const int C_lds_load_num = (BLOCK_M*BLOCK_K) / (4*32);
for(int warp_loop_temp=0; warp_loop_temp< C_lds_load_num/WARP_NUM; warp_loop_temp++) {
int warp_loop = warp_loop_temp * WARP_NUM + warp_id;
//global->lds, right matrix
int C_warp_buffer_load_k_id = (warp_loop / (BLOCK_K/32)); //seq_len
int C_warp_buffer_load_m_id = (warp_loop % (BLOCK_M/32)); //head_dim
{
int C_warp_buffer_load_global_offset = (C_warp_buffer_load_m_id * 32);
int C_warp_buffer_load_lds_offset = (C_warp_buffer_load_m_id * 32) + (C_warp_buffer_load_k_id * 4 * BLOCK_M);
if(Is_store_A){
C_warp_buffer_load_lds_offset = (C_warp_buffer_load_m_id * 32) + (C_warp_buffer_load_k_id * (4 * BLOCK_M + 2));
}
int C_gsoffset = (C_block_buffer_load_global_offset + C_warp_buffer_load_global_offset)/2 ;
int C_gvoffset;
if constexpr (Is_even_MN){
C_gvoffset = ((C_lane_m_idx * 2 + (C_lane_k_idx + C_warp_buffer_load_k_id*4)* M))/2 ;
} else {
C_gvoffset = ((C_lane_m_idx * 2 + min(C_lane_k_idx + C_warp_buffer_load_k_id*4, K-1)* M))/2 ;
}
// int gvOffset = (64*8 + lane_id*8)/2;
int A_lds_offset = (3*BLOCK_K*BLOCK_M + C_warp_buffer_load_lds_offset)/2;
if(Is_store_A){
A_lds_offset = (3*(BLOCK_K/32)*(BLOCK_M/32)*32*34 + C_warp_buffer_load_lds_offset)/2;
}
builtin_buffer_load_dword_lds(A_lds , C_ptr, A_lds_offset, C_gsoffset, C_gvoffset);
}
}
}
#endif
#endif
}
// K BLOCK_K BLOCK_N BLOCK_M BLOCK_K WARP_N
template<bool Is_preload_A, bool Is_store_A, bool Is_even_MN, int M/*head_dim*/, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum = float>
__forceinline__ __device__ void gpu_gemm_B_in_reg_gfx938(
vec4_uint A_ptr,
vec4_uint C_ptr,
Element* A_lds,
union_vec4_f16x2<Element> B_reg[(WARP_M/32)*(BLOCK_K/32)*2],
union_vec4_fp32 C_reg[(M/BLOCK_M)*(WARP_M/32)*(WARP_N/32)][4],
int N/*seq_kv*/,
int K/*seq_q*/,
int warp_id,
int seqlen_A_stride) {
#if 1
const int WARP_NUM = (BLOCK_M*BLOCK_N)/(WARP_M*WARP_N);
const int A_lds_load_num = (BLOCK_M*BLOCK_K) / (4*32);
static_assert(BLOCK_K>=32, "Error: gpu_gemm_B_in_reg gemm BLOCK_K must be equal or greater than 32");
static_assert(BLOCK_N>=WARP_N, "Error: gpu_gemm_B_in_reg gemm BLOCK_N must be equal or greater than WARP_N");
static_assert(BLOCK_M==WARP_M, "Error: gpu_gemm_B_in_reg gemm BLOCK_M must be equal to WARP_M");
union_vec4_f16x2<Element> A_reg[((WARP_M*BLOCK_K)/(32*32))*2];
//c mini tile is 32*32
// vec4_fp32 o[(WARP_M/32)*(WARP_N/32)][4]={0};
// __shared__ Element A_lds[STAGES*BLOCK_N * BLOCK_K];
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int row = lane_id % 4;
int col = lane_id / 4;
int stage_id = 0;
if(STAGES > 1 && (!Is_preload_A)) {
int m_loop = 0;
int A_block_buffer_load_global_offset = m_loop * BLOCK_M;
int A_lds_stage_offset = stage_id * BLOCK_M * BLOCK_K;
prefetch_to_lds_gfx938<false, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
#if 1
// int lds_offset = row * 8 + col * 32;
for(int m_loop = 1; m_loop<(M/BLOCK_M) + 1; m_loop++) {
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id = stage_id ^ 1;
}
}
if(STAGES == 1) {
m_loop--;
}
if((!Is_preload_A)&& m_loop < (M/BLOCK_M)) {
int A_block_buffer_load_global_offset = m_loop*BLOCK_M;
int A_lds_stage_offset = (stage_id)*BLOCK_K*BLOCK_M;
prefetch_to_lds_gfx938<false, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
//BM = 32, BK = 32
if(warp_id == 0) {
if(!Is_preload_A){
if(STAGES > 1 && m_loop < (M/BLOCK_M)) {
vmcnt_wait(1);
} else {
vmcnt_wait(0);
}
}
}
if constexpr (STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id --;
} else {
stage_id = stage_id ^ 1;
}
}
//lds -> vgpr use ds_read_m; left matrix
if(!Is_preload_A) {
int A_lds_stage_offset = stage_id * BLOCK_K * BLOCK_M;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg[0].f16, A_reg[1].f16, false);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(A_lds + A_lds_stage_offset);
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(f16_lds, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(f16_lds, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(A_lds + A_lds_stage_offset);
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 1024, 2, 1, 0);
}
} else {
// gfx938 m_ab = 0的gemm想要复用m_ab = 1的LDS数据
int A_lane_head_dim_idx = lane_id % 16;
int A_lane_seq_idx = lane_id / 16;
vec2_Element<Element> *A_lds_v2fp16 = (vec2_Element<Element> *)(A_lds);
for(int min_tile_k = 0; min_tile_k < 2; min_tile_k++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx++) {
//dword为单位
int lds_offset = stage_id * BLOCK_K * BLOCK_M / 2 + A_lane_seq_idx * 4 * 16 + vec_idx * 16 + min_tile_k * 16 * 16;
lds_offset += (A_lane_head_dim_idx + vec_idx / 2 * 4 + (A_lane_seq_idx % 2) * 8) % 16;
// int lds_offset = stage_id * BLOCK_K * BLOCK_M / 2 + A_lane_seq_idx/2 * 64 + A_lane_seq_idx % 2 * 16 + min_tile_k * 32 + vec_idx * 128 + A_lane_head_dim_idx;
inline_ds_read_b32_wait(A_lds_v2fp16, lds_offset, A_reg[min_tile_k].f16x2[vec_idx]);
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 1){
m_loop++;
}
// asm volatile("s_setprio 1");
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(BLOCK_K/32); k_idx++) { //BLOCK_K mini size is 32
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
} else {
//A采用ds_read后对应的mmac
C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
//BN = 32, BK = 32
// vec4_Element<Element>{A_reg[min_tile_k].f16[0*2 + min_tile_m], A_reg[min_tile_k].f16[1*2 + min_tile_m], A_reg[min_tile_k].f16[2*2 + min_tile_m], A_reg[min_tile_k].f16[3*2 + min_tile_m]},
vec4_Element<Element>{A_reg[min_tile_k].f16x2[0][min_tile_m], A_reg[min_tile_k].f16x2[1][min_tile_m], A_reg[min_tile_k].f16x2[2][min_tile_m], A_reg[min_tile_k].f16x2[3][min_tile_m]},
B_reg[min_tile_k].f16x4[min_tile_n],
C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
}
}
// asm volatile("s_setprio 0");
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id ^=1;
__builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
} else {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
#endif
#endif
}
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "intrinsic.h"
#include "numeric_types.h"
#include "intrinsic_mls_ds.h"
#include "prefetch.h"
// 无预取:prefetch_level = 0; 预取到LDS:prefetch_level = 1; 预取到寄存器:prefetch_level = 2;
template<bool Is_store_B, bool Is_preload_C, bool Is_even_MN, int A_prefetch_level, int B_prefetch_level, int K, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum = float>
__forceinline__ __device__ void gemm_tt_kq(vec4_uint A_ptr, vec4_uint B_ptr, Element* A_lds, Element* B_lds, int max_m_len_offset, int max_n_len_offset, union_vec2_f16x2<Element> A_reg[(K/BLOCK_K)*((WARP_M*BLOCK_K)/(32*32))*2/((A_prefetch_level == 3)? 1 : 2)][2], union_vec2_f16x2<Element> B_reg[STAGES*((WARP_N*BLOCK_K)/(32*32))*2][2], union_vec4_fp32 C_reg[(WARP_M/32)*(BLOCK_N/32)][4], int warp_id, int seqlen_A_stride, int seqlen_B_stride) {
const int WARP_NUM = BLOCK_M/WARP_M;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int row = lane_id % 4;
int col = lane_id / 4;
#if 1
for(int n_loop = 0 ; n_loop < BLOCK_N/WARP_N; n_loop++)
{
int stage_id = 0;
int stage_id_reg = 0;
{ int k_loop = 0;
if(STAGES > 1) {
if(A_prefetch_level == 0) {
int A_block_buffer_load_global_offset = k_loop * BLOCK_K;
int A_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*34);
buffer_load_lds_tile_pad(Is_even_MN, WARP_NUM, seqlen_A_stride, BLOCK_M, BLOCK_K, Element, A_ptr, A_lds, A_block_buffer_load_global_offset, A_lds_stage_offset, K, warp_id, lane_id);
}
if(B_prefetch_level == 0) {
int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
int B_lds_stage_offset = stage_id * (WARP_N/32) * (BLOCK_K/32)*(32*34);
if constexpr (Is_store_B){
B_lds_stage_offset += n_loop * (K/32) * (WARP_N/32)*(32*34);
}
buffer_load_lds_tile_pad_1(Is_even_MN, WARP_NUM, seqlen_B_stride, WARP_N, BLOCK_K, Element, B_ptr, B_lds, B_block_buffer_load_global_offset, B_lds_stage_offset, K, warp_id, lane_id);
}
}
}
// int lds_offset = row * 8 + col * 32;
for(int k_loop = 1; k_loop<(K/BLOCK_K) + 1; k_loop++) {
if constexpr (STAGES > 1) {
if constexpr (Is_store_B || B_prefetch_level == 1){
stage_id++;
} else {
stage_id ^= 1;
}
}
if(STAGES == 1) {
k_loop--;
}
if(k_loop < (K/BLOCK_K)){
if(A_prefetch_level == 0 || (A_prefetch_level == 1 && k_loop >= (K/BLOCK_K)/2)) {
int A_block_buffer_load_global_offset = k_loop * BLOCK_K;
int A_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*34);
buffer_load_lds_tile_pad(Is_even_MN, WARP_NUM, seqlen_A_stride, BLOCK_M, BLOCK_K, Element, A_ptr, A_lds, A_block_buffer_load_global_offset, A_lds_stage_offset, K, warp_id, lane_id);
}
if(B_prefetch_level == 0) {
int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
int B_lds_stage_offset = stage_id * (WARP_N/32) * (BLOCK_K/32)*(32*34);
if constexpr (Is_store_B || B_prefetch_level == 1){
B_lds_stage_offset += n_loop * (K/32) * (WARP_N/32)*(32*34);
}
buffer_load_lds_tile_pad_1(Is_even_MN, WARP_NUM, seqlen_B_stride, WARP_N, BLOCK_K, Element, B_ptr, B_lds, B_block_buffer_load_global_offset, B_lds_stage_offset, K, warp_id, lane_id);
}
}
else if (B_prefetch_level==0){
vmcnt_wait(0);
}
int precompute_B_lds_offset[2*2];
//lds -> vgpr use ds_read_m; right matrix
int k_warp_n_id = (warp_id & (WARP_N/WARP_N - 1));
int k_lds_stage_offset = STAGES == 1 ? 0 : ( (Is_store_B || B_prefetch_level == 1) ? (stage_id - 1) * (WARP_N/32) * (BLOCK_K/32)*(32*17) : (stage_id ^ 1) * (WARP_N/32) * (BLOCK_K/32)*(32*17));
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(B_lds);
//a warp load min size is (row, col) = (32,16) float
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(BLOCK_K/32); head_dim_idx++) { //32 half in col direction
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; vec_idx ++) {
int lds_offset = k_lds_stage_offset + head_dim_idx*(WARP_N*17) + (k_warp_n_id*(WARP_N/32) + n_idx)*(32*17) + vec_idx * 4 + min_tile_n*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) & 1)*8 + (lane_id/32);
precompute_B_lds_offset[min_tile_n * 2 + vec_idx] = lds_offset;
if constexpr (Is_store_B || B_prefetch_level == 1){
precompute_B_lds_offset[min_tile_n * 2 + vec_idx] += n_loop * (WARP_N/32) * (K/32)*(32*17);
}
}
}
}
}
if(STAGES > 1) {
if constexpr(B_prefetch_level==1){
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
vmcnt_wait(0);
} else {
vmcnt_wait(((BLOCK_N/WARP_N * K/BLOCK_K)*(Is_preload_C ? 2 : 1) - (n_loop * (K/BLOCK_K) + k_loop)) * (WARP_N*BLOCK_K) / (4*32)/WARP_NUM);
}
} else {
if(k_loop < (K/BLOCK_K)){
if(A_prefetch_level == 0 && B_prefetch_level != 0) {
buffer_load_lds_dwordx1_wait<(BLOCK_M * BLOCK_K) / (4*32)/WARP_NUM>();
} else if(A_prefetch_level != 0 && B_prefetch_level == 0) {
buffer_load_lds_dwordx1_wait<(WARP_N*BLOCK_K) / (4*32)/WARP_NUM>();
} else if(A_prefetch_level == 0 && B_prefetch_level == 0) {
buffer_load_lds_dwordx1_wait<(BLOCK_M * BLOCK_K) / (4*32)/WARP_NUM + (WARP_N*BLOCK_K) / (4*32)/WARP_NUM>();
}
}
}
} else {
vmcnt_wait(0);
}
if constexpr (STAGES > 1) {
if constexpr (Is_store_B || B_prefetch_level == 1){
stage_id--;
} else {
stage_id ^= 1;
}
}
union_vec2_f16x2<Element> A_reg_tmp[2][2];
if (A_prefetch_level == 0 || (A_prefetch_level != 3 && k_loop >= (K/BLOCK_K)/2 )) {
//lds -> vgpr use ds_read_m; left matrix
int A_warp_m_id = (warp_id & ((BLOCK_M/WARP_M) - 1));
int A_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*17);
vec2_Element<Element> *A_lds_v2fp16 = (vec2_Element<Element> *)(A_lds);
asm volatile("s_setprio 1");
// #pragma unroll
// for(int head_dim_idx=0; head_dim_idx<(BLOCK_K/32); head_dim_idx++) { //32 half in col direction
// #pragma unroll
// for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
// //a warp load min size is (row, col) = (32,16) float
// #pragma unroll
// for(int i=0; i<2; i++) { //sequence direction
// #pragma unroll
// for(int j=0; j<2; j++) { //head dim direction
// int lds_offset = A_lds_stage_offset + head_dim_idx*BLOCK_M*17 + (warp_id*(WARP_M/32) + m_idx)*(32*17) + j*4 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
// inline_ds_read2_b32_no_wait(A_lds_v2fp16, lds_offset, A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + i][j].f32, 2);
// }
// // #pragma unroll
// // for(int j=0; j<4; j++) { //head dim direction
// // int lds_offset = A_lds_stage_offset + head_dim_idx*BLOCK_M*17 + (warp_id*(WARP_M/32) + m_idx)*(32*17) + j*2 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
// // inline_ds_read_b32_wait(A_lds_v2fp16, lds_offset, A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + i][j/2].f16x2[j%2]);
// // }
// }
// }
// }
ds_read_tile_pad<WARP_M, BLOCK_K, WARP_NUM, Element>(A_lds_v2fp16, A_lds_stage_offset, A_reg_tmp, 0, warp_id, lane_id);
asm volatile("s_setprio 0");
}
// int k_warp_n_id = (warp_id & (WARP_N/WARP_N - 1));
// int k_lds_stage_offset = STAGES == 1 ? 0 : (stage_id ) * (WARP_N/32) * (BLOCK_K/32)*(32*17);
// if constexpr (Is_store_B || B_prefetch_level == 1){
// k_lds_stage_offset += n_loop * (WARP_N/32) * (K/32)*(32*17);
// }
// vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(B_lds);
ds_read2_tile_pad_no_wait(WARP_M, BLOCK_K, WARP_NUM, Element, k_lds_v2fp16, precompute_B_lds_offset, B_reg, stage_id_reg, k_warp_n_id, lane_id);
// ds_read2_tile_pad_no_wait<WARP_M, BLOCK_K, WARP_NUM, Element>(k_lds_v2fp16, k_lds_stage_offset, B_reg, stage_id_reg, k_warp_n_id, lane_id);
// ds_read_tile_pad<WARP_M, BLOCK_K, WARP_NUM, Element>(k_lds_v2fp16, k_lds_stage_offset, B_reg, stage_id_reg, k_warp_n_id, lane_id);
if constexpr (STAGES == 1){
k_loop++;
}
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
asm volatile("s_setprio 1");
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
const int lgkmcnt = 2 - min_tile_n*2;
lgkmcnt_wait(lgkmcnt);
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
for(int head_dim_idx=0; head_dim_idx< (BLOCK_K/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
if (A_prefetch_level == 0 || (A_prefetch_level != 3 && k_loop >= (K/BLOCK_K)/2 + 1 )){
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = flash::mmac<half_t, ElementAccum>(
vec4_Element<half_t>{
UpCast<Element, half_t, true>(A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][0]),
UpCast<Element, half_t, true>(A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][1]),
UpCast<Element, half_t, true>(A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][0]),
UpCast<Element, half_t, true>(A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][1])},
vec4_Element<half_t>{
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][0]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][1]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][0]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][1])},
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
} else {
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = flash::mmac<half_t, ElementAccum>(
vec4_Element<half_t>{
UpCast<Element, half_t, true>(A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][0]),
UpCast<Element, half_t, true>(A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][1]),
UpCast<Element, half_t, true>(A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][0]),
UpCast<Element, half_t, true>(A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][1])},
vec4_Element<half_t>{
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][0]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][1]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][0]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][1])},
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
}
} else {
if (A_prefetch_level == 0 || (A_prefetch_level != 3 && k_loop >= (K/BLOCK_K)/2 + 1 )){
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
vec4_Element<Element>{
A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][0],
A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][1],
A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][0],
A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][1]},
vec4_Element<Element>{
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][0],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][1],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][0],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][1]},
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
} else {
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
vec4_Element<Element>{
A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][0],
A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][1],
A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][0],
A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][1]},
vec4_Element<Element>{
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][0],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][1],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][0],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][1]},
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
}
}
}
asm volatile("s_setprio 0");
if constexpr (STAGES > 1){
if constexpr (!Is_store_B && B_prefetch_level !=1) {
stage_id ^= 1;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
} else{
stage_id++;
}
} else {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
}
#endif
}
// 无预取:prefetch_level = 0; 预取到LDS:prefetch_level = 1; 预取到寄存器:prefetch_level = 2;
template<bool Is_store_B, bool Is_preload_C, bool Is_even_MN, int A_prefetch_level, int B_prefetch_level, int K, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum = float>
__forceinline__ __device__ void gemm_tt_kq_gfx938(
vec4_uint A_ptr,
vec4_uint B_ptr,
Element* A_lds,
Element* B_lds,
int max_m_len_offset,
int max_n_len_offset,
union_vec4_f16x2<Element> A_reg[(K/BLOCK_K)*((WARP_M*BLOCK_K)/(32*32))*2/((A_prefetch_level == 3)? 1 : 2)],
union_vec4_f16x2<Element> B_reg[STAGES*((WARP_N*BLOCK_K)/(32*32))*2],
union_vec4_fp32 C_reg[(WARP_M/32)*(BLOCK_N/32)][4],
int warp_id,
int seqlen_A_stride,
int seqlen_B_stride) {
const int ELEMENT_BYTES = sizeof(Element);
const int WARP_NUM = BLOCK_M/WARP_M;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int row = lane_id % 4;
int col = lane_id / 4;
#if 1
for(int n_loop = 0 ; n_loop < BLOCK_N/WARP_N; n_loop++)
{
int stage_id = 0;
int stage_id_reg = 0;
{ int k_loop = 0;
if(STAGES > 1) {
if(A_prefetch_level == 0) {
int A_block_buffer_load_global_offset = k_loop * BLOCK_K;
int A_lds_stage_offset = stage_id * BLOCK_M* BLOCK_K;
prefetch_to_lds_gfx938<true, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
if(B_prefetch_level == 0) {
int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
int B_lds_stage_offset = stage_id * WARP_N * BLOCK_K;
if constexpr (Is_store_B){
B_lds_stage_offset += n_loop * K * WARP_N;
}
prefetch_to_lds_gfx938<true, WARP_N, BLOCK_K, Element, ElementAccum, Is_even_MN>(B_ptr, B_block_buffer_load_global_offset, B_lds + B_lds_stage_offset, seqlen_B_stride, warp_id);
}
}
}
// int lds_offset = row * 8 + col * 32;
for(int k_loop = 1; k_loop<(K/BLOCK_K) + 1; k_loop++) {
if constexpr (STAGES > 1) {
if constexpr (Is_store_B || B_prefetch_level == 1){
stage_id++;
} else {
stage_id ^= 1;
}
}
if(STAGES == 1) {
k_loop--;
}
if(k_loop < (K/BLOCK_K)){
if(A_prefetch_level == 0 || (A_prefetch_level == 1 && k_loop >= (K/BLOCK_K)/2)) {
int A_block_buffer_load_global_offset = k_loop * BLOCK_K;
int A_lds_stage_offset = stage_id * BLOCK_M* BLOCK_K;
prefetch_to_lds_gfx938<true, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
if(B_prefetch_level == 0) {
int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
int B_lds_stage_offset = stage_id * WARP_N * BLOCK_K;
if constexpr (Is_store_B){
B_lds_stage_offset += n_loop * K * WARP_N;
}
prefetch_to_lds_gfx938<true, WARP_N, BLOCK_K, Element, ElementAccum, Is_even_MN>(B_ptr, B_block_buffer_load_global_offset, B_lds + B_lds_stage_offset, seqlen_B_stride, warp_id);
}
}
else if (B_prefetch_level==0){
vmcnt_wait_nosync(0);
}
//MLS可以一条指令读32*32,而buffer_load_dword一条指令读2*64的数据,所以waitcnt需要进行修改
//且MLS是一个warp读32*32,4个warp读128*32,需要判断warp_id来wait
if(STAGES > 1) {
if constexpr(B_prefetch_level==1){
if((k_loop - 1) % WARP_NUM == warp_id)
{
if(Is_preload_C) {
vmcnt_wait_nosync(1);
} else {
vmcnt_wait_nosync(0);
}
}
} else {
if(k_loop < (K/BLOCK_K)){
if(A_prefetch_level == 0 && B_prefetch_level != 0) {
//BM = 128
vmcnt_wait_nosync((BLOCK_M * BLOCK_K) / (32*32)/WARP_NUM);
} else if(A_prefetch_level != 0 && B_prefetch_level == 0) {
//BN = 32
if(warp_id == 0) {
vmcnt_wait_nosync(1);
}
} else if(A_prefetch_level == 0 && B_prefetch_level == 0) {
//BM= 128 & BN = 32
if(warp_id == 0) {
vmcnt_wait_nosync((BLOCK_M * BLOCK_K) / (32*32)/WARP_NUM + 1);
} else {
vmcnt_wait_nosync(1);
}
}
}
}
} else {
vmcnt_wait_nosync(0);
}
__syncthreads();
if constexpr (STAGES > 1) {
if constexpr (Is_store_B || B_prefetch_level == 1){
stage_id--;
} else {
stage_id ^= 1;
}
}
union_vec4_f16x2<Element> A_reg_tmp[2];
if (A_prefetch_level == 0 || (A_prefetch_level != 3 && k_loop >= (K/BLOCK_K)/2 )) {
int A_lds_stage_offset = stage_id * BLOCK_M * BLOCK_K;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg_tmp[0].f16, A_reg_tmp[1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(A_lds + A_lds_stage_offset);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(A_lds + A_lds_stage_offset);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 1024, 2, 1, 0);
}
}
int B_lds_stage_offset = stage_id * WARP_N * BLOCK_K;
DS_READ_MATRIX_32X32_B16(ds_offset_cast(B_lds + B_lds_stage_offset), B_reg[0].f16, B_reg[1].f16, true);
if constexpr (STAGES == 1){
k_loop++;
}
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
// asm volatile("s_setprio 1");
lgkmcnt_wait(0);
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
for(int head_dim_idx=0; head_dim_idx< (BLOCK_K/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
} else {
if (A_prefetch_level == 0 || (A_prefetch_level != 3 && k_loop >= (K/BLOCK_K)/2 + 1 )){
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
A_reg_tmp[min_tile_m].f16x4[min_tile_k],
B_reg[min_tile_n].f16x4[min_tile_k],
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
} else {
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
A_reg[(k_loop - 1) * 2 + min_tile_m].f16x4[min_tile_k],
B_reg[min_tile_n].f16x4[min_tile_k],
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
}
}
}
// asm volatile("s_setprio 0");
if constexpr (STAGES > 1){
if constexpr (!Is_store_B && B_prefetch_level !=1) {
stage_id ^= 1;
__builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
} else{
stage_id++;
}
} else {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
}
#endif
}
\ No newline at end of file
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "utils.h"
#include "static_switch.h"
#include "numeric_types.h"
#include "intrinsic_mls_ds.h"
template<int K, int BLOCK_M, int BLOCK_K, int WARP_M, typename Element, typename ElementAccum, bool Is_even_MN>
inline __device__ void prefetch_to_vgpr(
vec4_uint k_ptr,
Element* k_lds,
union_vec2_f16x2<Element> k_reg[(K/BLOCK_K)*((WARP_M*BLOCK_K)/(32*32))*2][2],
int max_seq_k_offset,
int row_stride) {
const int WARP_NUM = (BLOCK_M)/(WARP_M);
const int k_lds_load_num = (BLOCK_M * BLOCK_K) / (4*32);
const int K_LOAD_REQUESTS = k_lds_load_num / WARP_NUM;
int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
int k_warp_m_id = (warp_id & ((BLOCK_M/WARP_M) - 1));
int lane_id = threadIdx.x & 63; //lane id, 0-63
int k_lane_m_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
int k_lane_head_dim_idx = lane_id & 15;
// int lds_offset = row * 8 + col * 32;
int stage_id = 0;
// MLS
vec4_uint k_srsrc;
k_srsrc[2] = row_stride; // stride
k_srsrc[3] = 0;
#pragma unroll
for(int k_loop = 0; k_loop<K/BLOCK_K; k_loop++) {
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
//global->lds, left matrix
int q_block_buffer_load_global_offset = k_loop * BLOCK_K ;//+ block_id_m * BLOCK_M * K;
// k_ptr buffer load mini size is 4*32, (BLOCK_M * BLOCK_K) mini size is (32*32)
int k_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*34);
for(int load = 0,warp_loop = warp_id; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7)*2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_m_id = (warp_loop & (BLOCK_M/4 - 1)); //这样子对L1和utlc1有啥影响呢?
// int q_warp_buffer_load_k_id = (warp_loop / (BLOCK_M/4));
int q_warp_buffer_load_lds_offset = k_lds_stage_offset/* + (q_warp_buffer_load_k_id * BLOCK_M * 34)*/ + ((k_warp_buffer_load_m_id >> 3)*(32*34) + (k_warp_buffer_load_m_id & 7)*(4*32));
// int q_warp_buffer_load_global_offset = (q_warp_buffer_load_k_id * 32);
int gvOffset_s = (q_block_buffer_load_global_offset/* + q_warp_buffer_load_global_offset*/) / 2;
int gvOffset_v;
if constexpr (not Is_even_MN) {
gvOffset_v = ((min(k_warp_buffer_load_m_id * 4 + k_lane_m_idx, max_seq_k_offset - 1)) * row_stride) / 2 + k_lane_head_dim_idx;
} else {
gvOffset_v = ((k_warp_buffer_load_m_id * 4 + k_lane_m_idx) * row_stride) / 2 + k_lane_head_dim_idx;
}
int lds_offset = (q_warp_buffer_load_lds_offset + padding) / 2; // + lane_id;
builtin_buffer_load_dword_lds_bypass_glc_slc(k_lds, k_ptr, lds_offset, gvOffset_s, gvOffset_v);
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
// k_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
ds_read_tile_pad(WARP_M, BLOCK_K, WARP_NUM, Element, k_lds_v2fp16, k_lds_stage_offset, k_reg, k_loop, warp_id, lane_id);
}
}
}
//matrix_load单位:32 * 32
//ds_read_matrix单位:32 * 16
//M = 128, N = 128
template<bool trans, int M, int N, typename Element, typename ElementAccum, bool Is_even_MN>
inline __device__ void prefetch_to_vgpr_gfx938(
vec4_uint ptr,
Element* lds,
union_vec4_f16x2<Element> reg[M * N / (64 * 8)],//vec4_fp16x2有8个element,64个线程
int max_column_offset,
int warp_id) {
constexpr int ELEMENT_BYTES = sizeof(Element);
const int stages = 2;
const int WARP_NUM = 4;
int row_stride = ptr[2];
vec4_uint srsrc;
srsrc[2] = row_stride;
srsrc[3] = 0;
//计算LDS地址,每个warp使用一个32*32
int lds_offset = (warp_id * 32 * 32);
size_t lds_load_offset = reinterpret_cast<size_t>(lds) + lds_offset * ELEMENT_BYTES;
int stages_id = 0;
if(stages == 2) {
int m_loop = 0;
int n_loop = 0;
int global_offset = (warp_id * row_stride * 32 + n_loop * 32);
int lds_offset_stage = (lds_offset + stages_id * (WARP_NUM * 32 * 32)) * ELEMENT_BYTES;
if constexpr (!Is_even_MN) {
//对M方向进行边界判断,看需要pad多少0
int nm_filter_max = (m_loop * 128 + (warp_id + 1) * 32) - max_column_offset;
int nm_filter = max(0, (m_loop * 128 + (warp_id + 1) * 32) - max_column_offset);
if(nm_filter_max >= 32) {
global_offset = (0 * row_stride * 32 + n_loop * 32);
nm_filter = max(0, (m_loop * 128 + 0 * 32) - max_column_offset);
}
srsrc[3] = nm_filter << 8; // set only once
}
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset * ELEMENT_BYTES);
union union_vec4_uint rsrc_bits;
rsrc_bits.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(lds) + lds_offset_stage;
if(trans) {
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
} else {
matrix_load_b16_lds_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
}
}
for(int m_loop = 0; m_loop < M / 128; ++m_loop) {
for(int n_loop = stages - 1; n_loop < N / 32 + stages - 1; ++n_loop) {
if(stages == 2) {
stages_id ^= 1;
}
//更新global地址
int global_offset = (warp_id * row_stride * 32 + n_loop * 32);
int lds_offset_stage = (lds_offset + stages_id * (WARP_NUM * 32 * 32)) * ELEMENT_BYTES;
// size_t lds_load_offset_stage = reinterpret_cast<size_t>(lds) + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) * ELEMENT_BYTES + lds_offset * ELEMENT_BYTES;
if constexpr (!Is_even_MN) {
//对M方向进行边界判断,看需要pad多少0
int nm_filter_max = (m_loop * 128 + (warp_id + 1) * 32) - max_column_offset;
int nm_filter = max(0, (m_loop * 128 + (warp_id + 1) * 32) - max_column_offset);
if(nm_filter_max >= 32) {
global_offset = (0 * row_stride * 32 + n_loop * 32);
nm_filter = max(0, (m_loop * 128 + 0 * 32) - max_column_offset);
}
srsrc[3] = nm_filter << 8; // set only once
}
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset * ELEMENT_BYTES);
if(n_loop < N / 32) {
union union_vec4_uint rsrc_bits;
rsrc_bits.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(lds) + lds_offset_stage;
if(trans) {
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
} else {
matrix_load_b16_lds_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
}
}
if(stages == 2 && n_loop < N /32) {
vmcnt_wait_nosync(1);
} else {
vmcnt_wait_nosync(0);
}
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
if(trans){
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(lds_load_offset_stage), reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16, reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 1024, 2, 1, 0);
}
} else {
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(lds_load_offset_stage), reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16, reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16, false);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_format_f16(f16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_format_f16(f16_lds, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 1024, 2, 1, 0);
}
}
lgkmcnt_wait(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
}
}
}
//matrix_load单位:32 * 32
//ds_read_matrix单位:32 * 16
//M = 32, N = 128
template<bool trans, int M, int N, typename Element, typename ElementAccum, bool Is_even_MN, int WARP_NUM = 4>
inline __device__ void prefetch_to_lds_gfx938(
vec4_uint ptr,
int global_start_offset,
Element* lds,
int max_column_offset,
int warp_id) {
const int ELEMENT_BYTES = sizeof(Element);
const int LOAD_NUM = M * N / (32 * 32);
int row_stride = ptr[2];
vec4_uint srsrc;
srsrc[2] = row_stride;
srsrc[3] = 0;
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
//直接拉通M * N,看有多少个 32*32 的矩阵需要load
for(int loop = 0; loop < (LOAD_NUM + WARP_NUM - 1) / WARP_NUM; loop++) {
int loop_warp = loop * WARP_NUM + warp_id;
if (loop_warp < LOAD_NUM) {
int m_loop = loop_warp / (N / 32);
int n_loop = loop_warp % (N / 32);
//更新global地址
int global_offset = (global_start_offset + m_loop * row_stride + n_loop * 32) * ELEMENT_BYTES;
if constexpr (!Is_even_MN) {
//对M方向进行边界判断,看需要pad多少0
int nm_filter_max = (m_loop + 1) * 32 - max_column_offset;
int nm_filter = nm_filter_max;
if(nm_filter_max >= 32) {
global_offset = (global_start_offset + 0 * row_stride + n_loop * 32) * ELEMENT_BYTES;
nm_filter = (0 + 1) * 32 - max_column_offset;
}
nm_filter = max(0, nm_filter);
srsrc[3] = nm_filter << 8; // set only once
}
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset);
//计算LDS地址,每个warp使用一个32*32;下一个loop重复利用
int lds_offset = (loop_warp * 32 * 32) * ELEMENT_BYTES;
union union_vec4_uint rsrc_bits;
rsrc_bits.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(lds) + lds_offset;
if (trans) {
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
} else {
matrix_load_b16_lds_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
}
}
}
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
}
template<bool Is_even_MN, int K/*head_dim*/, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, typename Element>
__forceinline__ __device__ void prefetch_to_tmp_lds_wait(vec4_uint B_ptr, Element* B_lds, int max_n_len_offset, int warp_id, int row_stride)
{
const int WARP_NUM = BLOCK_M/WARP_M;
int lane_id = threadIdx.x & 63; //lane id, 0-63
for(int n_loop = 0 ; n_loop < BLOCK_N/WARP_N; n_loop++){
for(int k_loop = 0; k_loop < K/BLOCK_K; k_loop++) {
const int lgkmcnt = (BLOCK_N/WARP_N * K/BLOCK_K - 1) - (n_loop * K/BLOCK_K + k_loop);
lgkmcnt_wait(lgkmcnt);
int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
// headdim=256时的LDS用量为 256/32 * 32 * 34 * 2byte= 17 KB,如果同时读Q和dO到LDS,就会超过32KB
// headdim=224时的LDS用量为 224/32 * 32 * 34 * 2byte= 14.875 KB,如果同时读Q和dO到LDS,不会超32KB
int B_lds_stage_offset = k_loop * (WARP_N/32) * (BLOCK_K/32)*(32*34) + n_loop * (K/32) * (WARP_N/32)*(32*34);
buffer_load_lds_tile_pad(Is_even_MN, WARP_NUM, row_stride, WARP_N, BLOCK_K, Element, B_ptr, B_lds, B_block_buffer_load_global_offset, B_lds_stage_offset, max_n_len_offset, warp_id, lane_id);
}
}
}
#pragma once
#include "numeric_types.h"
#include "utils.h"
using namespace flash;
//32*32的tile,结果矩阵根据奇偶分开设计
//mask_type == 0:无mask
//mask_type == 1: mask矩阵右上角
//mask_type == 2: mask矩阵左下角
template <bool Is_even_MN, int mask_type>
inline __device__ void apply_mask_bwd(union_vec4_fp32 tensor[1][4], int M, int N, int M_minus_N, int window_size_left, int window_size_right) {
const int lane_id = threadIdx.x & 63;
const int lane_m_idx = (lane_id & 15);
const int lane_n_idx = (lane_id >> 4);
//无mask,仅进行边界判断
if(!Is_even_MN && mask_type == 0) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = lane_n_idx * 2 + min_tile_n + vec_idx * 8;
if(N_offset > N - 1){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//mask右上角
if (mask_type == 1 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = lane_m_idx * 2 + min_tile_m;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = lane_n_idx * 2 + min_tile_n + vec_idx * 8;
int N_limit = Is_even_MN ? (M_offset + M_minus_N) : min(N - 1, M_offset + M_minus_N);
if(N_offset > N_limit){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//mask左下角
if (mask_type == 2 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = lane_m_idx * 2 + min_tile_m;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = lane_n_idx * 2 + min_tile_n + vec_idx * 8;
int N_limit = (M_offset + M_minus_N);
if((!Is_even_MN && N_offset > N - 1) || N_offset < N_limit){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//local mask
if (mask_type == 3) {// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128))
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = lane_m_idx * 2 + min_tile_m;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = lane_n_idx * 2 + min_tile_n + vec_idx * 8;
int N_limit_left = (M_offset + M_minus_N - window_size_left);
int N_limit_right = (M_offset + M_minus_N + window_size_right);
if((!Is_even_MN && N_offset > N - 1) || N_offset <= N_limit_left || N_offset >= N_limit_right){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
}
//32*32的tile,结果矩阵根据mmac_4interleave设计
//mask_type == 0:无mask
//mask_type == 1: mask矩阵右上角
//mask_type == 2: mask矩阵左下角
template <bool Is_even_MN, int mask_type>
inline __device__ void apply_mask_bwd_gfx938(union_vec4_fp32 tensor[1][4], int M, int N, int M_minus_N, int window_size_left, int window_size_right) {
const int lane_id = threadIdx.x & 63;
const int lane_m_idx = (lane_id & 15);
const int lane_n_idx = (lane_id >> 4);
//无mask,仅进行边界判断
if(!Is_even_MN && mask_type == 0) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
if(N_offset > N - 1){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//mask右上角
if (mask_type == 1 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = min_tile_m * 16 + lane_m_idx;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
int N_limit = Is_even_MN ? (M_offset + M_minus_N) : min(N - 1, M_offset + M_minus_N);
if(N_offset > N_limit){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//mask左下角
if (mask_type == 2 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = min_tile_m * 16 + lane_m_idx;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
int N_limit = (M_offset + M_minus_N);
if((!Is_even_MN && N_offset > N - 1) || N_offset < N_limit){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//local mask
if (mask_type == 3) {// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128))
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = min_tile_m * 16 + lane_m_idx;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
int N_limit_left = (M_offset + M_minus_N - window_size_left);
int N_limit_right = (M_offset + M_minus_N + window_size_right);
if((!Is_even_MN && N_offset > N - 1) || N_offset <= N_limit_left || N_offset >= N_limit_right){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
}
template <bool encode_dropout_in_sign_bit=false, typename DataType, int WARP_M, int WARP_N>
inline __device__ void apply_dropout(const DataType tensor[(WARP_M/32)*(WARP_N/32)][4], uint8_t p_dropout_in_uint8_t,
unsigned long long seed, unsigned long long offset,
int block_col_start, int block_row_start,
int block_col_stride) {
// tensor has shape (8, MMA_M, MMA_N / 2)
auto encode_dropout = [](bool keep, DataType val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : DataType(0));
};
// static_assert(decltype(size<2>(tensor))::value % 2 == 0);
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
#pragma unroll
for (int n = 0; n < (WARP_N/32); ++n, block_col_start += block_col_stride) {
uint2 rowcol = make_uint2(block_row_start, block_col_start);
#pragma unroll
for (int m = 0; m < (WARP_M/32); ++m, ++rowcol.x) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
// the random value is less than the threshold.
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
// We're exploiting the fact that floating point comparison is equivalent to integer
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit
&& (std::is_same<DataType, Float16>::value || std::is_same<DataType, BFloat16>::value)) {
// uint16_t rnd_16[16];
// #pragma unroll
// for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
// uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
// #pragma unroll
// for (int j = 0; j < 2; j++) {
// Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
// // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// uint32_t mask;
// asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
// tensor_uint32(i) &= mask;
// }
// // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
// }
} else {
//min tile for a warp is 32*32
#pragma unroll
for (int n_idx = 0; n_idx < 2; n_idx++) {
#pragma unroll
for (int m_idx = 0; m_idx < 2; m_idx++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) { //mmac min_tile is 16*16, a warp is 64 thread
tensor[(n*(WARP_N/16)*(WARP_M/16) + m*(WARP_M/16)) + n_idx * 2 + m_idx][vec_idx] = encode_dropout(rnd_8[n_idx * 8 + m_idx] <= p_dropout_in_uint8_t, tensor[(n*(WARP_N/16)*(WARP_M/16) + m*(WARP_M/16)) + n_idx * 2 + m_idx][vec_idx]);
}
}
// Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
}
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
// // }
}
}
}
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void thread_reduce_(const DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if(zero_init == true) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
summary[m_idx*2 + min_tile_m] = (OpType==0)? 0 : -INFINITY; //OpType:0 is sum operator, 1 is max operator
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) { //mmac min_tile is 16*16, a warp is 64 thread
summary[m_idx*2 + min_tile_m] = op(summary[m_idx*2 + min_tile_m], tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2 + min_tile_m][vec_idx]);
}
}
}
}
}
} else {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
summary_cur[m_idx*2 + min_tile_m] = summary[m_idx*2 + min_tile_m];// op(summary[m_idx*2 + min_tile_m], tensor[m_idx][min_tile_m][0]);
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) { //mmac min_tile is 16*16, a warp is 64 thread
summary_cur[m_idx*2 + min_tile_m] = op(summary_cur[m_idx*2 + min_tile_m], tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2 + min_tile_m][vec_idx]);
}
}
}
}
}
}
}
template<typename Operator, typename DataType, int WARP_M>
__device__ inline void quad_allreduce_(DataType *dst, DataType *src, Operator &op) {
#pragma unroll
for (int i = 0; i < (WARP_M/16); i++) {
dst[i] = Allreduce<64>::run(src[i], op);
}
}
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void reduce_(const DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if(zero_init == true) {
thread_reduce_<true, Operator, OpType, DataType0, DataType1, WARP_M, WARP_N>(tensor, summary, op);
quad_allreduce_<Operator, DataType1, WARP_M>(summary, summary, op);
} else {
thread_reduce_<false, Operator, OpType, DataType0, DataType1, WARP_M, WARP_N>(tensor, summary, op, summary_cur);
quad_allreduce_<Operator, DataType1, WARP_M>(summary_cur, summary_cur, op);
}
}
//zero_init==true, max is current max_score, max_cur=nullptr
//zero_init==true, max is prev max_score, max_cur!=nullptr
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void reduce_max(const DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *max , DataType1 *max_cur=nullptr) {
MaxOp<float> max_op;
if(zero_init == true) {
reduce_<true, MaxOp<float>, 1, DataType0, DataType1, WARP_M, WARP_N>(tensor, max, max_op);
} else {
reduce_<false, MaxOp<float>, 1, DataType0, DataType1, WARP_M, WARP_N>(tensor, max, max_op, max_cur);
}
}
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void reduce_sum(DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *sum, DataType1 *sum_cur=nullptr){
SumOp<float> sum_op;
if(zero_init == true) {
reduce_<true, SumOp<float>, 0, DataType0, DataType1, WARP_M, WARP_N>(tensor, sum, sum_op);
} else {
reduce_<false, SumOp<float>, 0, DataType0, DataType1, WARP_M, WARP_N>(tensor, sum, sum_op, sum_cur);
}
}
// Apply the exp to all the elements.
template <bool Scale_max=true, int BLOCK_M, int WARP_N, typename DataType0, typename DataType1>
inline __device__ void scale_apply_exp2_bwd(DataType0 tensor[(BLOCK_M/32)*(WARP_N/32)][4], const DataType1 *max, const float scale) {
// #if defined(__gfx936__)
// auto vec2_scale = vec2_fp32{scale, scale};
// #endif
#pragma unroll
for (int mi = 0; mi < (BLOCK_M/32); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) {
const float max_scaled = (max[(mi*2 + min_tile_m)*4 + vec_idx] * (Scale_max ? scale : float(M_LOG2E)));
// #if defined(__gfx936__)
// auto vec2_max_scaled = vec2_fp32{-max_scaled, -max_scaled};
// #endif
#pragma unroll
for (int ni = 0; ni < (WARP_N/32); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#if 0//defined(__gfx936__)
auto vec2_tensor = vec2_fp32{tensor[ni + mi*(WARP_N/32)][min_tile_m*2].f32[vec_idx], tensor[ni + mi*(WARP_N/32)][min_tile_m*2 + 1].f32[vec_idx]};
auto vec2_scale = vec2_fp32{scale, scale};
auto vec2_max_scaled = vec2_fp32{-max_scaled, -max_scaled};
auto tensor_tmp =
hcu_pk_fma_f32(
vec2_tensor,
vec2_scale,
vec2_max_scaled);
// __builtin_hcu_v_pk_fma_f32(
// vec2_tensor,
// vec2_scale,
// vec2_max_scaled);
tensor[ni + mi*(WARP_N/32)][min_tile_m*2].f32[vec_idx] = __llvm_exp2_f32(tensor_tmp[0]);
tensor[ni + mi*(WARP_N/32)][min_tile_m*2 + 1].f32[vec_idx] = __llvm_exp2_f32(tensor_tmp[1]);
#else
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { //使用__llvm_exp2_f32会产生nan,使用exp2f则没问题
// tensor[ni + mi*(WARP_N/32)][min_tile_n + min_tile_m*2].f32[vec_idx] =exp2f(tensor[ni + mi*(WARP_N/32)][min_tile_n + min_tile_m*2].f32[vec_idx] * scale - max_scaled);
tensor[ni + mi*(WARP_N/32)][min_tile_n + min_tile_m*2].f32[vec_idx] =__llvm_exp2_f32(tensor[ni + mi*(WARP_N/32)][min_tile_n + min_tile_m*2].f32[vec_idx] * scale - max_scaled);
}
#endif
}
}
}
}
}
// Apply the exp to all the elements.
template <bool Scale_max=true, int WARP_M, int BLOCK_N, typename DataType0, typename DataType1>
inline __device__ void scale_apply_exp2_bwd_seq_q_major(DataType0 tensor[(BLOCK_N/32)*(WARP_M/32)][4], const DataType1 max[WARP_M/16], const float scale) {
// const float max_scaled = max[0] * float(M_LOG2E);
#pragma unroll
for (int ni = 0; ni < (BLOCK_N/32); ++ni) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const float max_scaled = (max[mi*2 + min_tile_m] * (Scale_max ? scale : float(M_LOG2E)));
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] =
__llvm_exp2_f32(tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] * scale - max_scaled);
// tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] =
// exp2f(tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] * scale - max_scaled);
// tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] =
// __llvm_exp2_f32(tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] * scale - max_scaled + 64) * __llvm_exp2_f32(-64);
}
}
}
}
}
}
#if 0
template<bool Is_first, bool Check_inf=false, typename DataType0, typename DataType1,int K/*head_dim*/, int kBlockK, int WARP_M, int WARP_N>
inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N/32)*(WARP_M/32)][4], DataType1 *scores_max, DataType1 *scores_sum,
DataType0 acc_o[(K/kBlockK) * ((WARP_M/32)*(kBlockK/32))][4], float softmax_scale_log2) {
if (Is_first) {
reduce_max</*zero_init=*/true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max);
scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max, softmax_scale_log2);
reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_sum);
} else {
float scores_max_cur[WARP_M/16]; //calculate max of each row
reduce_max</*zero_init=*/false, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max, scores_max_cur); // scores_max is prev scores max
for (int mi = 0; mi < (WARP_M/32); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
float scores_max_cur_reg = !Check_inf
? scores_max_cur[mi*2 + min_tile_m]
: (scores_max_cur[mi*2 + min_tile_m] == -INFINITY ? 0.0f : scores_max_cur[mi*2 + min_tile_m]);
float scores_scale = __llvm_exp2_f32((scores_max[mi*2 + min_tile_m] - scores_max_cur_reg) * softmax_scale_log2);
scores_sum[mi*2 + min_tile_m] *= scores_scale;
#pragma unroll
for(int pv_n_loop=0; pv_n_loop<(K/kBlockK); pv_n_loop++) {
#pragma unroll
for (int ni = 0; ni < (kBlockK/32); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
for(int vec_idx=0; vec_idx<4; vec_idx++) {
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m][vec_idx] = acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m][vec_idx] * scores_scale;
}
}
}
}
}
}
scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max_cur, softmax_scale_log2);
float scores_sum_cur[WARP_M/16]={0.0f};
reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_sum_cur);
#pragma unroll
for (int mi = 0; mi < (WARP_M/16); ++mi) { scores_sum[mi] += scores_sum_cur[mi]; }
}
};
#endif
#pragma once
#include "hip/hip_fp16.h"
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include "numeric_types.h"
#include "intrinsic.h"
#if defined(__gfx936__) || defined(__gfx938__)
#define parallel_degree 3
#else
#define parallel_degree 2
#endif
template<typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
if (result) {
throw std::runtime_error(std::string("[GPU][ERROR] HIP runtime error: ") + hipGetErrorString(result) + " " + file + ":" + std::to_string(line) + " \n");
}
}
#define check_hip_error(val) check((val), #val, __FILE__, __LINE__)
namespace flash {
inline __device__ constexpr int ceil_div(int const& a, int const& b) {
return (a + b - 1) / b;
}
template<class T>
__device__ vec4_fp32 mmac(const vec4_Element<T> &v1, const vec4_Element<T> &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#endif
}
template<>
__device__ vec4_fp32 mmac<half_t>(const vec4_fp16 &v1, const vec4_fp16 &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#endif
}
template<>
__device__ vec4_fp32 mmac<bhalf_t>(const vec4_bf16 &v1, const vec4_bf16 &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_bf16(v1, v2, v3);
#endif
}
template<typename T>
__forceinline__ __device__ T __shfl_xor_tmp(T x, const int lane_mask) {
int lane_id = threadIdx.x & 63;
int index = (lane_id ^ lane_mask) << 2;
int res = __builtin_amdgcn_ds_bpermute(index, *(int*)&x); // attention, __builtin only support int
return *(T*)&res;
}
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) {
T res = (x + y);
return res;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 64);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_tmp(x, 32));
return op(x, __shfl_xor_tmp(x, 16));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<32> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
//x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
x = op(x, __shfl_xor_tmp(x, 16));
return x;
}
};
template<typename T, int WARP_M>
void copy(T *src, T *dst) {
for(int i=0; i<(WARP_M/16); i++) {
dst[i] = src[i];
}
}
//TODO:后续优化得用上V_CVT_PKRTZ_FP16_FP32
//QK(seq_q, seq_k), two float in seq_k dim convert to two half, and packed into a U32
template <int BLOCK_M, int WARP_N, typename ElementType>
inline __device__ void convert_pk_type(union_vec2_f16x2<ElementType> p_reg[(BLOCK_M/32)*(WARP_N/32)][4], union_vec4_fp32 s_reg[(BLOCK_M/32)*(WARP_N/32)][4]) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(BLOCK_M/32); m_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
// for(int vec_idx=0; vec_idx<4; vec_idx++) {
// p_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f16[vec_idx] = DownCast<float,ElementType,true>(s_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f32[vec_idx]);
// }
for(int vec_idx=0; vec_idx<2; vec_idx++) {
p_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f16x2[vec_idx][0] = DownCast<float,ElementType,true>(s_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f32[vec_idx*2]);
p_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f16x2[vec_idx][1] = DownCast<float,ElementType,true>(s_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f32[vec_idx*2+1]);
}
}
}
}
}
}
//TODO:后续优化得用上V_CVT_PKRTZ_FP16_FP32
//QK(seq_q, seq_k), two float in seq_k dim convert to two half, and packed into a U32
template <int BLOCK_M, int WARP_N, typename ElementType>
inline __device__ void convert_pk_type_gfx938(union_vec4_f16x2<ElementType> p_reg[(BLOCK_M/32)*(WARP_N/32)*2], union_vec4_fp32 s_reg[(BLOCK_M/32)*(WARP_N/32)][4]) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(BLOCK_M/32); m_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) {
p_reg[(n_idx + m_idx*(WARP_N/32)) * 2 + min_tile_n].f16[min_tile_m * 4 + vec_idx] = DownCast<float,ElementType,false>(s_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f32[vec_idx]);
// p_reg[(n_idx + m_idx*(WARP_N/32)) * 2 + min_tile_n].f16[min_tile_m * 4 + vec_idx] = s_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
}
template<const int kHeadDim, typename T>
__device__ __forceinline__ vec4_uint tcp_cache_swizzle_func(T* ptr) {
vec4_uint res;
*(uint64_t*)&res = reinterpret_cast<uint64_t>(ptr);
if constexpr (kHeadDim == 196) {
res[1] += 0x41800000; // 62 bit: cache swizzle; 48~61: Stride
} else if constexpr (kHeadDim == 128) {
res[1] += 0x41000000; // stride 256 Bytes and change tagram
} else if constexpr (kHeadDim == 64) {
res[1] += 0x40800000; // stride 128 Bytes and change tagram
}
res[2] = 0x80000000;
res[3] = 0x00020000;
return res;
}
template<typename T>
__device__ __forceinline__ vec4_uint prepare_for_matrix_load_gfx938(T* ptr, int row_stride) {
vec4_uint srsrc;
*(uint64_t*)&srsrc = reinterpret_cast<uint64_t>(ptr);
srsrc[2] = row_stride;
srsrc[3] = 0;
return srsrc;
}
template<class T, class AccumType>
inline __device__ vec4_fp32 mmac(const vec4_Element<T> &v1, const vec4_Element<T> &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#endif
}
template<>
inline __device__ vec4_fp32 mmac<half_t, float>(const vec4_fp16 &v1, const vec4_fp16 &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#endif
}
template<>
inline __device__ vec4_fp32 mmac<bhalf_t, float>(const vec4_bf16 &v1, const vec4_bf16 &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_bf16(v1, v2, v3);
#endif
}
//封装buffer_load
template<int Is_M_equal,int WARP_NUM,int N_row_len,int M,int N,typename Element>
__forceinline__ __device__ void buffer_load_lds_tile(vec4_uint global_ptr, Element* lds_ptr, int global_offset, int lds_stage_offset, int max_M_len, int warp_id, int lane_id) {
int bytes_per_Element = 2;
if constexpr (std::is_same<Element, int8_t>::value || std::is_same<Element, Float8_e4m3_t>::value) {
bytes_per_Element = 1;
}
int Elment_per_dword = 4/bytes_per_Element;
//M维度index变换,(0, 1, 2, 3) --> (0, 2, 1, 3)
int lane_M_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1);
int lane_N_idx = lane_id & 15;
const int lds_load_num = (M*N*bytes_per_Element) / (4*64);
// for(int warp_loop=warp_id; warp_loop<lds_load_num; warp_loop+=WARP_NUM) {
for(int load = 0,warp_loop = warp_id; load < lds_load_num/WARP_NUM; warp_loop += WARP_NUM, ++load) {
int warp_buffer_load_lds_offset = lds_stage_offset + warp_loop * (4*32);
int gsOffset = global_offset/Elment_per_dword;
int gvOffset;
if constexpr (Is_M_equal){
gvOffset = (warp_loop * 4 + lane_M_idx) * N_row_len/Elment_per_dword + lane_N_idx;
} else {
gvOffset = (min(warp_loop * 4 + lane_M_idx, max_M_len - 1) * N_row_len)/Elment_per_dword + lane_N_idx;
}
int lds_offset = warp_buffer_load_lds_offset/Elment_per_dword;
builtin_buffer_load_dword_lds(lds_ptr, global_ptr, lds_offset, gsOffset, gvOffset);
}
}
//封装buffer_load
template<int Is_M_equal,int WARP_NUM,int N_row_len,int M,int N,typename Element>
__forceinline__ __device__ void buffer_load_lds_tile_pad(vec4_uint global_ptr, Element* lds_ptr, int global_offset, int lds_stage_offset, int max_M_len, int warp_id, int lane_id) {
int bytes_per_Element = 2;
if constexpr (std::is_same<Element, int8_t>::value || std::is_same<Element, Float8_e4m3_t>::value) {
bytes_per_Element = 1;
}
int Elment_per_dword = 4/bytes_per_Element;
//M维度index变换,(0, 1, 2, 3) --> (0, 2, 1, 3)
int lane_M_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1);
int lane_N_idx = lane_id & 15;
const int lds_load_num = (M*N*bytes_per_Element) / (4*64);
for(int load = 0,warp_loop = warp_id; load < lds_load_num/WARP_NUM; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7)*2; // padding size in shared memory per buffer load, to avoid bank conflict
int warp_buffer_load_lds_offset = lds_stage_offset + ((warp_loop >> 3)*(32*34) + ( warp_loop & 7)*(4*32));
int gsOffset = global_offset/Elment_per_dword;
int gvOffset;
if constexpr (Is_M_equal){
gvOffset = (warp_loop * 4 + lane_M_idx) * N_row_len/Elment_per_dword + lane_N_idx;
} else {
gvOffset = (min(warp_loop * 4 + lane_M_idx, max_M_len - 1) * N_row_len)/Elment_per_dword + lane_N_idx;
}
int lds_offset = (warp_buffer_load_lds_offset + padding)/Elment_per_dword;
builtin_buffer_load_dword_lds(lds_ptr, global_ptr, lds_offset, gsOffset, gvOffset);
}
}
//封装ds_read
template<int M, int N, int WARP_NUM, typename Element>
__forceinline__ __device__ void ds_read_tile_pad(vec2_Element<Element>* lds_v2fp16, int lds_stage_offset, union_vec2_f16x2<Element> (*reg)[2], int loop, int warp_id, int lane_id){
#pragma unroll
for(int m_idx = 0; m_idx < M / 32; m_idx ++){
#pragma unroll
for(int n_idx = 0; n_idx < N / 32; n_idx ++){
#pragma unroll
for(int i=0; i<2; i++) {
#pragma unroll
for(int j=0; j<4; j++) {
int lds_offset = lds_stage_offset + n_idx*M*17 + (warp_id*(M/32) + m_idx)*(N*17) + j*2 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
inline_ds_read_b32_wait(lds_v2fp16, lds_offset, reg[(loop)*((M*N)/(32*32))*2 + (n_idx*(M/32) + m_idx)*2 + i][j/2].f16x2[j%2]);
}
}
}
}
}
//封装ds_read2
template<int M, int N, int WARP_NUM, typename Element>
__forceinline__ __device__ void ds_read2_tile_pad_no_wait(vec2_Element<Element>* lds_v2fp16, int lds_stage_offset, union_vec2_f16x2<Element> (*reg)[2], int loop, int warp_id, int lane_id){
#pragma unroll
for(int m_idx = 0; m_idx < M / 32; m_idx ++){
#pragma unroll
for(int n_idx = 0; n_idx < N / 32; n_idx ++){
#pragma unroll
for(int i=0; i<2; i++) {
#pragma unroll
for(int j=0; j<2; j++) {
int lds_offset = lds_stage_offset + n_idx*M*17 + (warp_id*(M/32) + m_idx)*(N*17) + j*4 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
inline_ds_read2_b32_no_wait(lds_v2fp16, lds_offset, reg[(loop)*((M*N)/(32*32))*2 + (n_idx*(M/32) + m_idx)*2 + i][j].f32, 2);
}
}
}
}
}
//封装buffer_load
#define buffer_load_lds_tile_pad(Is_M_equal, WARP_NUM, N_row_len, M, N, Element, global_ptr, lds_ptr, global_offset, lds_stage_offset, max_M_len, warp_id, lane_id)\
{\
int bytes_per_Element = 2;\
if constexpr (std::is_same<Element, int8_t>::value || std::is_same<Element, Float8_e4m3_t>::value) {\
bytes_per_Element = 1;\
}\
int Elment_per_dword = 4/bytes_per_Element;\
int lane_M_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1);\
int lane_N_idx = lane_id & 15;\
const int lds_load_num = (M*N*bytes_per_Element) / (4*64);\
for(int load = 0,warp_loop = warp_id; load < lds_load_num/WARP_NUM; warp_loop += WARP_NUM, ++load) {\
int padding = (warp_loop & 7);\
int gsOffset = global_offset/Elment_per_dword;\
int gvOffset;\
if constexpr (Is_M_equal){\
gvOffset = (warp_loop * 4 + lane_M_idx) * N_row_len/Elment_per_dword + lane_N_idx;\
} else {\
gvOffset = (min(warp_loop * 4 + lane_M_idx, max_M_len - 1) * N_row_len)/Elment_per_dword + lane_N_idx;\
}\
int lds_offset = lds_stage_offset/Elment_per_dword + padding + warp_loop * 64;\
builtin_buffer_load_dword_lds(lds_ptr, global_ptr, lds_offset, gsOffset, gvOffset);\
}\
}
//封装buffer_load
#define buffer_load_lds_tile_pad_1(Is_M_equal, WARP_NUM, N_row_len, M, N, Element, global_ptr, lds_ptr, global_offset, lds_stage_offset, max_M_len, warp_id, lane_id)\
{\
int bytes_per_Element = 2;\
if constexpr (std::is_same<Element, int8_t>::value || std::is_same<Element, Float8_e4m3_t>::value) {\
bytes_per_Element = 1;\
}\
int Elment_per_dword = 4/bytes_per_Element;\
int lane_M_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1);\
int lane_N_idx = lane_id & 15;\
const int lds_load_num = (M*N*bytes_per_Element) / (4*64);\
for(int load = 0,warp_loop = warp_id; load < lds_load_num/WARP_NUM; warp_loop += WARP_NUM, ++load) {\
int padding = (warp_loop & 7);\
int gsOffset = global_offset/Elment_per_dword;\
int gvOffset;\
gvOffset = (warp_loop * 4 + lane_M_idx) * N_row_len/Elment_per_dword + lane_N_idx;\
int lds_offset = lds_stage_offset/Elment_per_dword + padding + warp_loop * 64;\
builtin_buffer_load_dword_lds(lds_ptr, global_ptr, lds_offset, gsOffset, gvOffset);\
}\
}
//封装buffer_load
#define buffer_load_lds_tile(Is_M_equal, WARP_NUM, N_row_len, M, N, Element, global_ptr, lds_ptr, global_offset, lds_stage_offset, max_M_len, warp_id, lane_id)\
{\
int bytes_per_Element = 2;\
if constexpr (std::is_same<Element, int8_t>::value || std::is_same<Element, Float8_e4m3_t>::value) {\
bytes_per_Element = 1;\
}\
int Elment_per_dword = 4/bytes_per_Element;\
int lane_M_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1);\
int lane_N_idx = lane_id & 15;\
const int lds_load_num = (M*N*bytes_per_Element) / (4*64);\
for(int load = 0,warp_loop = warp_id; load < lds_load_num/WARP_NUM; warp_loop += WARP_NUM, ++load) {\
int gsOffset = global_offset/Elment_per_dword;\
int gvOffset;\
if constexpr (Is_M_equal){\
gvOffset = (warp_loop * 4 + lane_M_idx) * N_row_len/Elment_per_dword + lane_N_idx;\
} else {\
gvOffset = (min(warp_loop * 4 + lane_M_idx, max_M_len - 1) * N_row_len)/Elment_per_dword + lane_N_idx;\
}\
int lds_offset = lds_stage_offset/Elment_per_dword + warp_loop * 64;\
builtin_buffer_load_dword_lds(lds_ptr, global_ptr, lds_offset, gsOffset, gvOffset);\
}\
}
#define ds_read_tile_pad(M, N, WARP_NUM, Element, lds_v2fp16, lds_stage_offset, reg, loop, warp_id, lane_id)\
{\
for(int m_idx = 0; m_idx < M / 32; m_idx ++){\
for(int n_idx = 0; n_idx < N / 32; n_idx ++){\
for(int i=0; i<2; i++) {\
for(int j=0; j<4; j++) {\
int lds_offset = lds_stage_offset + n_idx*M*17 + (warp_id*(M/32) + m_idx)*(N*17) + j*2 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);\
inline_ds_read_b32_wait(lds_v2fp16, lds_offset, reg[(loop)*((M*N)/(32*32))*2 + (n_idx*(M/32) + m_idx)*2 + i][j/2].f16x2[j%2]);\
}\
}\
}\
}\
}
#define ds_read2_tile_pad_no_wait(M,N,WARP_NUM,Element,lds_v2fp16,precompute_offset,reg,loop,warp_id,lane_id)\
{\
for(int m_idx = 0; m_idx < M / 32; m_idx ++){\
for(int n_idx = 0; n_idx < N / 32; n_idx ++){\
for(int i=0; i<2; i++) {\
for(int j=0; j<2; j++) {\
inline_ds_read2_b32_no_wait(lds_v2fp16, precompute_B_lds_offset[i*2+j], reg[(loop)*((M*N)/(32*32))*2 + (n_idx*(M/32) + m_idx)*2 + i][j].f32, 2); \
}\
}\
}\
}\
}
#define ds_offset_cast(OFFSET) \
static_cast<int>(reinterpret_cast<uintptr_t>(OFFSET) & 0xFFFFFFFF)
}
#pragma once
#include <iostream>
#include "flash.h"
# define FLASH_HOST_DEVICE __forceinline__ __host__ __device__
# define FLASH_DEVICE __forceinline__ __device__
# define FLASH_HOST __forceinline__ __host__
#define HIP_KERNEL_LAUNCH_CHECK() { \
hipError_t error = hipGetLastError(); \
if (error != hipSuccess) { \
std::cout << "HIP Kernel Launch error: " << hipGetErrorString(error) << std::endl;\
}\
}
#define HIP_CHECK(func) { \
hipError_t error = func; \
if (error != hipSuccess) { \
std::cout << "HIP API call error: " << hipGetErrorString(error) << std::endl;\
}\
}
#define PRINT_TENSOR_INFO(tensor, name) \
std::cout << name << ": shape " << tensor.sizes() << ", stride " << tensor.strides() << ", contiguous " << std::boolalpha << tensor.is_contiguous() << "\n";
#define PRINT_QKV_INFO(q, k, v) \
std::cout << "qkv shape: " << q.sizes() << ", " << k.sizes() << ", " << v.sizes() << "\n"; \
std::cout << "qkv stride: " << q.strides() << ", " << k.strides() << ", " << v.strides() << "\n"; \
std::cout << "qkv contiguous: " << std::boolalpha << q.is_contiguous() << ", " << k.is_contiguous() << ", " << v.is_contiguous() << "\n";
#define PRINT_TENSOR(tensor, name) \
{ \
auto temp_tensor = tensor.to(at::DeviceType::CPU).contiguous(); \
std::vector<int32_t> temp_vector(temp_tensor.data_ptr<int32_t>(), temp_tensor.data_ptr<int32_t>() + temp_tensor.numel()); \
printf("%s: [", name); \
for (const auto val: temp_vector) { printf("%d ", val); } \
printf("]\n"); \
}
#define PRINT_PARAMS \
printf("layout: %d\n", params.layout);\
printf("mtp: %d\n", params.mtp);\
printf("is_causal: %d\n", params.is_causal); \
printf("is_bf16: %d\n", params.is_bf16); \
printf("is_e4m3: %d\n", params.is_e4m3); \
printf("b: %d\n", params.b);\
printf("h: %d\n", params.h);\
printf("h_k: %d\n", params.h_k);\
printf("h_h_k_ratio: %d\n", params.h_h_k_ratio);\
printf("seqlen_q: %d\n", params.seqlen_q);\
printf("seqlen_k: %d\n", params.seqlen_k);\
printf("total_q: %d\n", params.total_q);\
printf("seqlen_knew: %d\n", params.seqlen_knew);\
printf("seqlen_q_rounded: %d\n", params.seqlen_q_rounded);\
printf("seqlen_k_rounded: %d\n", params.seqlen_k_rounded);\
printf("ngroups: %d\n", params.ngroups);\
printf("seqlenq_ngroups_swapped: %d\n", params.seqlenq_ngroups_swapped);\
printf("d: %d\n", params.d);\
printf("dv: %d\n", params.d_value);\
printf("q_batch_stride: %d\n", params.q_batch_stride);\
printf("q_row_stride: %d\n", params.q_row_stride);\
printf("q_head_stride: %d\n", params.q_head_stride);\
printf("k_batch_stride: %d\n", params.k_batch_stride);\
printf("k_row_stride: %d\n", params.k_row_stride);\
printf("k_head_stride: %d\n", params.k_head_stride);\
printf("v_batch_stride: %d\n", params.v_batch_stride);\
printf("v_row_stride: %d\n", params.v_row_stride);\
printf("v_head_stride: %d\n", params.v_head_stride);\
printf("o_batch_stride: %d\n", params.o_batch_stride);\
printf("o_row_stride: %d\n", params.o_row_stride);\
printf("o_head_stride: %d\n", params.o_head_stride);\
printf("varlen_proj_qkv_head: %d\n", params.varlen_proj_qkv_head);\
printf("knew_batch_stride: %d\n", params.knew_batch_stride);\
printf("vnew_batch_stride: %d\n", params.vnew_batch_stride);\
printf("knew_row_stride: %d\n", params.knew_row_stride);\
printf("vnew_row_stride: %d\n", params.vnew_row_stride);\
printf("knew_head_stride: %d\n", params.knew_head_stride);\
printf("vnew_head_stride: %d\n", params.vnew_head_stride);\
printf("window_size_left: %d\n", params.window_size_left);\
printf("window_size_right: %d\n", params.window_size_right);\
printf("qkvheaddim_compute: %d\n", params.qkvheaddim_compute);\
printf("qkvheaddim_tail_tile16: %d\n", params.qkvheaddim_tail_tile16);\
printf("softcap: %.12f\n", params.softcap);\
printf("p_dropout: %.12f\n", params.p_dropout);\
printf("rp_dropout: %.12f\n", params.rp_dropout);\
printf("p_dropout_in_uint8_t: %u\n", params.p_dropout_in_uint8_t);\
printf("random_seed: %llu\n", params.rand_seed);\
printf("random_offset: %llu\n", params.rand_offset);\
printf("scale_softmax: %.12f\n", params.scale_softmax);\
printf("scale_softmax_log2: %.12f\n", params.scale_softmax_log2);\
printf("is_seqlens_k_cumulative: %d\n", params.is_seqlens_k_cumulative); \
printf("q_ptr: %p\n", params.q_ptr); \
printf("k_ptr: %p\n", params.k_ptr); \
printf("v_ptr: %p\n", params.v_ptr); \
printf("o_ptr: %p\n", params.o_ptr); \
printf("qv_ptr: %p\n", params.qv_ptr); \
printf("cu_seqlens_q: %p\n", params.cu_seqlens_q); \
printf("cu_seqlens_k: %p\n", params.cu_seqlens_k); \
printf("seqused_k: %p\n", params.seqused_k); \
printf("attn_mask: %p\n", params.attn_mask); \
printf("padding_mask: %p\n", params.padding_mask); \
printf("softmax_lse_ptr: %p\n", params.softmax_lse_ptr); \
printf("softmax_lseaccum_ptr: %p\n", params.softmax_lseaccum_ptr); \
printf("scores_sum_ptr: %p\n", params.scores_sum_ptr); \
printf("scores_max_ptr: %p\n", params.scores_max_ptr); \
printf("oaccum_ptr: %p\n", params.oaccum_ptr); \
printf("block_table : %p\n", params.block_table); \
printf("page_block_size: %d\n", params.page_block_size); \
printf("block_table_batch_stride: %d\n", params.block_table_batch_stride); \
printf("num_pages : %d\n", params.num_pages); \
printf("num_splits : %d\n", params.num_splits); \
printf("partition_size : %d\n", params.partition_size); \
printf("splitkv_use_fp32_as_accum: %d\n", params.splitkv_use_fp32_as_accum);
// output details in oneline
#define PRINT_PARAMS_ONELINE \
printf("{layout: %d, is_causal: %d, is_bf16: %d, is_e4m3: %d, b: %d, h: %d, h_k: %d, h_h_k_ratio: %d, seqlen_q: %d, seqlen_k: %d, total_q: %d, ngroups: %d, seqlenq_ngroups_swapped: %d, d: %d, dv: %d,\
q_batch_stride: %d, q_row_stride: %d, q_head_stride: %d, k_batch_stride: %d, k_row_stride: %d, k_head_stride: %d, v_batch_stride: %d, v_row_stride: %d, v_head_stride: %d, o_batch_stride: %d, o_row_stride: %d, o_head_stride: %d,\
varlen_proj_qkv_head: %d, window_size_left: %d, window_size_right: %d, softcap: %.12f, scale_softmax: %.12f, scale_softmax_log2: %.12f, is_seqlens_k_cumulative: %d,\
q_ptr: %p, k_ptr: %p, v_ptr: %p, o_ptr: %p, qv: %p, cu_seqlens_q: %p, cu_seqlens_k: %p, seqused_k: %p, attn_mask: %p, padding_mask: %p, softmax_lse_ptr: %p, scores_sum_ptr: %p, scores_max_ptr: %p, oaccum_ptr: %p, block_table : %p,\
page_block_size: %d, block_table_batch_stride: %d, num_pages : %d, num_splits : %d, partition_size : %d, splitkv_use_fp32_as_accum: %d}\n",\
params.layout, \
params.is_causal, \
params.is_bf16, \
params.is_e4m3, \
params.b, \
params.h, \
params.h_k, \
params.h_h_k_ratio, \
params.seqlen_q, \
params.seqlen_k, \
params.total_q, \
params.ngroups, \
params.seqlenq_ngroups_swapped, \
params.d, \
params.d_value, \
params.q_batch_stride, \
params.q_row_stride, \
params.q_head_stride, \
params.k_batch_stride, \
params.k_row_stride, \
params.k_head_stride, \
params.v_batch_stride, \
params.v_row_stride, \
params.v_head_stride, \
params.o_batch_stride, \
params.o_row_stride, \
params.o_head_stride, \
params.varlen_proj_qkv_head, \
params.window_size_left, \
params.window_size_right, \
params.softcap, \
params.scale_softmax, \
params.scale_softmax_log2, \
params.is_seqlens_k_cumulative, \
params.q_ptr, \
params.k_ptr, \
params.v_ptr, \
params.o_ptr, \
params.qv_ptr, \
params.cu_seqlens_q, \
params.cu_seqlens_k, \
params.seqused_k, \
params.attn_mask, \
params.padding_mask, \
params.softmax_lse_ptr, \
params.scores_sum_ptr, \
params.scores_max_ptr, \
params.oaccum_ptr, \
params.block_table , \
params.page_block_size, \
params.block_table_batch_stride, \
params.num_pages , \
params.num_splits , \
params.partition_size , \
params.splitkv_use_fp32_as_accum \
);
__attribute__((weak)) void printFlashBwdParams(const Flash_bwd_params& params) {
std::cout << "Flash_bwd_params:\n";
// 打印 Flash_fwd_params 成员
std::cout << "o_ptr: " << params.o_ptr << "\n";
std::cout << "oaccum_ptr: " << params.oaccum_ptr << "\n";
#ifdef DEBUGING
std::cout << "qk_ptr: " << params.qk_ptr << "\n";
std::cout << "qk_softmax_ptr: " << params.qk_softmax_ptr << "\n";
#endif
std::cout << "o_batch_stride: " << params.o_batch_stride << "\n";
std::cout << "o_row_stride: " << params.o_row_stride << "\n";
std::cout << "o_head_stride: " << params.o_head_stride << "\n";
std::cout << "p_ptr: " << params.p_ptr << "\n";
std::cout << "softmax_lse_ptr: " << params.softmax_lse_ptr << "\n";
std::cout << "softmax_lseaccum_ptr: " << params.softmax_lseaccum_ptr << "\n";
std::cout << "scores_sum_ptr: " << params.scores_sum_ptr << "\n";
std::cout << "scores_max_ptr: " << params.scores_max_ptr << "\n";
std::cout << "b: " << params.b << "\n";
std::cout << "seqlen_q: " << params.seqlen_q << "\n";
std::cout << "seqlen_k: " << params.seqlen_k << "\n";
std::cout << "seqlen_knew: " << params.seqlen_knew << "\n";
std::cout << "d: " << params.d << "\n";
std::cout << "seqlen_q_rounded: " << params.seqlen_q_rounded << "\n";
std::cout << "seqlen_k_rounded: " << params.seqlen_k_rounded << "\n";
std::cout << "d_rounded: " << params.d_rounded << "\n";
std::cout << "rotary_dim: " << params.rotary_dim << "\n";
std::cout << "total_q: " << params.total_q << "\n";
std::cout << "scale_softmax: " << params.scale_softmax << "\n";
std::cout << "scale_softmax_log2: " << params.scale_softmax_log2 << "\n";
std::cout << "cu_seqlens_q: " << params.cu_seqlens_q << "\n";
std::cout << "cu_seqlens_k: " << params.cu_seqlens_k << "\n";
std::cout << "leftpad_k: " << params.leftpad_k << "\n";
std::cout << "seqused_k: " << params.seqused_k << "\n";
// std::cout << "blockmask: " << params.blockmask << "\n";
std::cout << "knew_ptr: " << params.knew_ptr << "\n";
std::cout << "vnew_ptr: " << params.vnew_ptr << "\n";
std::cout << "knew_batch_stride: " << params.knew_batch_stride << "\n";
std::cout << "vnew_batch_stride: " << params.vnew_batch_stride << "\n";
std::cout << "knew_row_stride: " << params.knew_row_stride << "\n";
std::cout << "vnew_row_stride: " << params.vnew_row_stride << "\n";
std::cout << "knew_head_stride: " << params.knew_head_stride << "\n";
std::cout << "vnew_head_stride: " << params.vnew_head_stride << "\n";
std::cout << "rotary_cos_ptr: " << params.rotary_cos_ptr << "\n";
std::cout << "rotary_sin_ptr: " << params.rotary_sin_ptr << "\n";
std::cout << "cache_batch_idx: " << params.cache_batch_idx << "\n";
std::cout << "block_table: " << params.block_table << "\n";
std::cout << "block_table_batch_stride: " << params.block_table_batch_stride << "\n";
std::cout << "page_block_size: " << params.page_block_size << "\n";
std::cout << "p_dropout: " << params.p_dropout << "\n";
std::cout << "p_dropout_in_uint8_t: " << (int)params.p_dropout_in_uint8_t << "\n";
std::cout << "rp_dropout: " << params.rp_dropout << "\n";
std::cout << "scale_softmax_rp_dropout: " << params.scale_softmax_rp_dropout << "\n";
std::cout << "window_size_left: " << params.window_size_left << "\n";
std::cout << "window_size_right: " << params.window_size_right << "\n";
std::cout << "softcap: " << params.softcap << "\n";
std::cout << "rand_seed: " << params.rand_seed << "\n";
std::cout << "rand_offset: " << params.rand_offset << "\n";
std::cout << "dropout_debug_count: " << params.dropout_debug_count << "\n";
std::cout << "rng_state: " << params.rng_state << "\n";
std::cout << "is_bf16: " << params.is_bf16 << "\n";
std::cout << "is_causal: " << params.is_causal << "\n";
std::cout << "is_seqlens_k_cumulative: " << params.is_seqlens_k_cumulative << "\n";
std::cout << "is_rotary_interleaved: " << params.is_rotary_interleaved << "\n";
std::cout << "num_splits: " << params.num_splits << "\n";
std::cout << "partition_size: " << params.partition_size << "\n";
std::cout << "alibi_slopes_ptr: " << params.alibi_slopes_ptr << "\n";
std::cout << "alibi_slopes_batch_stride: " << params.alibi_slopes_batch_stride << "\n";
std::cout << "unpadded_lse: " << params.unpadded_lse << "\n";
std::cout << "seqlenq_ngroups_swapped: " << params.seqlenq_ngroups_swapped << "\n";
// 打印 Flash_bwd_params 独有成员
std::cout << "q_ptr: " << params.q_ptr << "\n";
std::cout << "k_ptr: " << params.k_ptr << "\n";
std::cout << "v_ptr: " << params.v_ptr << "\n";
std::cout << "o_ptr: " << params.o_ptr << "\n";
std::cout << "softmax_lse_ptr: " << params.softmax_lse_ptr << "\n";
std::cout << "do_ptr: " << params.do_ptr << "\n";
std::cout << "dq_ptr: " << params.dq_ptr << "\n";
std::cout << "dk_ptr: " << params.dk_ptr << "\n";
std::cout << "dv_ptr: " << params.dv_ptr << "\n";
std::cout << "dq_accum_ptr: " << params.dq_accum_ptr << "\n";
std::cout << "dk_accum_ptr: " << params.dk_accum_ptr << "\n";
std::cout << "dv_accum_ptr: " << params.dv_accum_ptr << "\n";
#ifdef DEBUGING
std::cout << "kq_ptr: " << params.kq_ptr << "\n";
std::cout << "s_ptr: " << params.s_ptr << "\n";
std::cout << "dp_ptr: " << params.dp_ptr << "\n";
std::cout << "ds_ptr: " << params.ds_ptr << "\n";
#endif
std::cout << "do_batch_stride: " << params.do_batch_stride << "\n";
std::cout << "do_row_stride: " << params.do_row_stride << "\n";
std::cout << "do_head_stride: " << params.do_head_stride << "\n";
std::cout << "dq_batch_stride: " << params.dq_batch_stride << "\n";
std::cout << "dk_batch_stride: " << params.dk_batch_stride << "\n";
std::cout << "dv_batch_stride: " << params.dv_batch_stride << "\n";
std::cout << "dq_row_stride: " << params.dq_row_stride << "\n";
std::cout << "dk_row_stride: " << params.dk_row_stride << "\n";
std::cout << "dv_row_stride: " << params.dv_row_stride << "\n";
std::cout << "dq_head_stride: " << params.dq_head_stride << "\n";
std::cout << "dk_head_stride: " << params.dk_head_stride << "\n";
std::cout << "dv_head_stride: " << params.dv_head_stride << "\n";
std::cout << "dsoftmax_sum: " << params.dsoftmax_sum << "\n";
std::cout << "deterministic: " << params.deterministic << "\n";
std::cout << "dq_accum_split_stride: " << params.dq_accum_split_stride << "\n";
}
#define PRINT_BWD_PARAMS printf("b is %d\n", params.b); \
printf("params.h is %d \n",params.h); \
printf("params.h_k is %d \n",params.h_k); \
printf("params.h_h_k_ratio is %d \n",params.h_h_k_ratio); \
printf("params.d is %d \n",params.d); \
printf("params.seqlen_q is %d \n",params.seqlen_q); \
printf("params.seqlen_k is %d \n",params.seqlen_k); \
printf("params.q_row_stride is %d \n",params.q_row_stride); \
printf("params.k_row_stride is %d \n",params.k_row_stride); \
printf("params.v_row_stride is %d \n",params.v_row_stride); \
printf("params.o_row_stride is %d \n",params.o_row_stride); \
printf("params.do_row_stride is %d \n",params.do_row_stride); \
printf("params.dq_row_stride is %d \n",params.dq_row_stride); \
printf("params.dk_row_stride is %d \n",params.dk_row_stride); \
printf("params.dv_row_stride is %d \n",params.dv_row_stride); \
printf("params.q_head_stride is %d \n",params.q_head_stride); \
printf("params.k_head_stride is %d \n",params.k_head_stride); \
printf("params.v_head_stride is %d \n",params.v_head_stride); \
printf("params.o_head_stride is %d \n",params.o_head_stride); \
printf("params.dq_head_stride is %d \n",params.dq_head_stride); \
printf("params.do_head_stride is %d \n",params.do_head_stride); \
printf("params.dk_head_stride is %d \n",params.dk_head_stride); \
printf("params.dv_head_stride is %d \n",params.dv_head_stride); \
printf("params.q_batch_stride is %d \n",params.q_batch_stride); \
printf("params.k_batch_stride is %d \n",params.k_batch_stride); \
printf("params.o_batch_stride is %d \n",params.o_batch_stride); \
printf("params.do_batch_stride is %d \n",params.do_batch_stride); \
printf("params.dq_batch_stride is %d \n",params.dq_batch_stride); \
printf("params.dk_batch_stride is %d \n",params.dk_batch_stride); \
printf("params.dv_batch_stride is %d \n",params.dv_batch_stride); \
printf("params.scale_softmax is %d \n",params.scale_softmax); \
printf("params.deterministic is %d \n",params.deterministic);
#define PRINT_MLA_PARAMS \
printf("layout: %d\n", params.layout);\
printf("mtp: %d\n", params.mtp);\
printf("is_causal: %d\n", params.is_causal); \
printf("is_bf16: %d\n", params.is_bf16); \
printf("is_e4m3: %d\n", params.is_e4m3); \
printf("b: %d\n", params.b);\
printf("h: %d\n", params.h);\
printf("h_k: %d\n", params.h_k);\
printf("h_h_k_ratio: %d\n", params.h_h_k_ratio);\
printf("total_q: %d\n", params.total_q);\
printf("seqlen_q: %d\n", params.seqlen_q);\
printf("seqlen_k: %d\n", params.seqlen_k);\
printf("ngroups: %d\n", params.ngroups);\
printf("seqlenq_ngroups_swapped: %d\n", params.seqlenq_ngroups_swapped);\
printf("d: %d\n", params.d);\
printf("q_batch_stride: %d\n", params.q_batch_stride);\
printf("q_row_stride: %d\n", params.q_row_stride);\
printf("q_head_stride: %d\n", params.q_head_stride);\
printf("k_batch_stride: %d\n", params.k_batch_stride);\
printf("k_row_stride: %d\n", params.k_row_stride);\
printf("k_head_stride: %d\n", params.k_head_stride);\
printf("v_batch_stride: %d\n", params.v_batch_stride);\
printf("v_row_stride: %d\n", params.v_row_stride);\
printf("v_head_stride: %d\n", params.v_head_stride);\
printf("qv_batch_stride: %d\n", params.qv_batch_stride);\
printf("qv_row_stride: %d\n", params.qv_row_stride);\
printf("qv_head_stride: %d\n", params.qv_head_stride);\
printf("o_batch_stride: %d\n", params.o_batch_stride);\
printf("o_row_stride: %d\n", params.o_row_stride);\
printf("o_head_stride: %d\n", params.o_head_stride);\
printf("scale_softmax: %.12f\n", params.scale_softmax);\
printf("scale_softmax_log2: %.12f\n", params.scale_softmax_log2);\
printf("is_seqlens_k_cumulative: %d\n", params.is_seqlens_k_cumulative); \
printf("q_ptr: %p\n", params.q_ptr); \
printf("k_ptr: %p\n", params.k_ptr); \
printf("v_ptr: %p\n", params.v_ptr); \
printf("qv_ptr: %p\n", params.qv_ptr); \
printf("o_ptr: %p\n", params.o_ptr); \
printf("cu_seqlens_q: %p\n", params.cu_seqlens_q); \
printf("cu_seqlens_k: %p\n", params.cu_seqlens_k); \
printf("cu_seqlens_k_new: %p\n", params.cu_seqlens_k_new); \
printf("leftpad_k: %p\n", params.leftpad_k); \
printf("num_splits_ptr: %p\n", params.num_splits_ptr); \
printf("tile_scheduler_metadata_ptr: %p\n", params.tile_scheduler_metadata_ptr); \
printf("oaccum_ptr: %p\n", params.oaccum_ptr); \
printf("scores_max_ptr: %p\n", params.scores_max_ptr); \
printf("scores_sum_ptr: %p\n", params.scores_sum_ptr); \
printf("softmax_lse_ptr: %p\n", params.softmax_lse_ptr); \
printf("block_table : %p\n", params.block_table); \
printf("page_block_size: %d\n", params.page_block_size); \
printf("block_table_batch_stride: %d\n", params.block_table_batch_stride); \
printf("num_splits : %d\n", params.num_splits); \
printf("partition_size : %d\n", params.partition_size); \
printf("splitkv_use_fp32_as_accum: %d\n", params.splitkv_use_fp32_as_accum); \
printf("cu_count : %d\n", params.cu_count);
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment