Commit 5813dcc1 authored by zhanghj2's avatar zhanghj2
Browse files

添加softmax

parent 0e1300f7
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<64>::run(src(i), op);
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void warp_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &smem_reduce, Operator &op) {
const int tidx = threadIdx.x;
const int row = tidx % 16;
const int col = tidx / 64;
const int warp_id = tidx / 64;
// static_assert(size(dst) == 1);
// 这里两种写法,一种是写连续,读不连续;另一种是读不连续,写连续。如何权衡?性能影响不大
if ((tidx % 64) / 16 == 0)
// if (tidx >= warp_id * 64 && tidx <= warp_id * 64 + 16)
{
// smem_reduce(row + warp_id * 16) = dst(0);
smem_reduce(row * 4 + warp_id * 1) = dst(0);
// smem_reduce(row, col) = dst(0);
}
__syncthreads();
if (tidx < 16)
{
smem_reduce(row + 64) = op(op(smem_reduce(row * 4), smem_reduce(row * 4 + 1)), op(smem_reduce(row * 4 + 2), smem_reduce(row * 4 + 3)));
}
__syncthreads();
dst(0) = smem_reduce(row + 64);
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void warp_allreduce_tp1(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &smem_reduce, Operator &op) {
const int tidx = threadIdx.x;
const int col = (tidx % 64) / 16;
const int warp_id = tidx / 64;
const int row = tidx % 16 + (warp_id % 4) * 16;
// 0-4 1-5 2-6 3-7
if (col == 0) {
// printf("sum %d %d %d %d %.2f \n", row * 2 + (warp_id / 4), row, tidx, warp_id, dst(0));
smem_reduce[row * 2 + (warp_id / 4)] = dst[0];
}
__syncthreads();
if (col == 0 && warp_id < 4) {
// printf("sum %d %d %d %.2f %.2f \n", row, tidx, warp_id, smem_reduce[row * 2], smem_reduce[row * 2 + warp_id / 4]);
smem_reduce[128 + row] = op(smem_reduce[row * 2], smem_reduce[row * 2 + 1]);
}
__syncthreads();
dst(0) = smem_reduce(128 + row);
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void warp_allreduce_tp4(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &smem_reduce, Operator &op) {
const int tidx = threadIdx.x;
const int col = (tidx % 64) / 16;
const int warp_id = tidx / 64;
const int row = tidx % 16 + (warp_id % 2) * 16;
// 0-4 1-5 2-6 3-7
if (col == 0) {
// printf("sum %d %d %d %d %.2f \n", row * 2 + (warp_id / 4), row, tidx, warp_id, dst(0));
smem_reduce[row * 2 + (warp_id / 2)] = dst[0];
}
__syncthreads();
if (col == 0 && warp_id < 2) {
// printf("sum %d %d %d %.2f %.2f \n", row, tidx, warp_id, smem_reduce[row * 2], smem_reduce[row * 2 + warp_id / 4]);
smem_reduce[row + 64] = op(smem_reduce[row * 2], smem_reduce[row * 2 + 1]);
}
__syncthreads();
dst(0) = smem_reduce(row + 64);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// The following macro will disable the use of fma.
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
// This macro is set in PyTorch and not FlashAttention
#if 0
#ifdef UNFUSE_FMA
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
#else
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
#endif
#else
tensor(mi, ni) = __builtin_amdgcn_exp2f(tensor(mi, ni) * scale - max_scaled);
#endif
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
__forceinline__ __device__ Softmax() {};
template<bool Is_first, bool Check_inf=false, bool is_tp1=false, typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
if constexpr (is_tp1)
{
flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
}
else
{
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
if constexpr (is_tp1)
{
flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
}
else
{
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
} // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
// if (blockIdx.x == 0 && threadIdx.x == 0)
// {
// printf("threadIdx.x %.2f, scores_scale = %.4f\n",row_sum(mi), scores_scale );
// }
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
// if (block0())
// {
// printf("normalize_softmax_lse %.4f\n", row_sum(0));
// }
};
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_fp8(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2,
v4f& c0_0, v4f& c0_1, v4f& c1_0, v4f& c1_1, v4f& c2_0, v4f& c2_1, v4f& c3_0, v4f& c3_1
) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(1 == kNRows);
// #pragma unroll
// for (int mi = 0; mi < size(row_max); ++mi)
{
int mi = 0;
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
row_sum(mi) *= scores_scale;
// #pragma unroll
// for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
c0_0.x *= scores_scale; c0_0.y *= scores_scale; c0_0.z *= scores_scale; c0_0.w *= scores_scale;
c0_1.x *= scores_scale; c0_1.y *= scores_scale; c0_1.z *= scores_scale; c0_1.w *= scores_scale;
c1_0.x *= scores_scale; c1_0.y *= scores_scale; c1_0.z *= scores_scale; c1_0.w *= scores_scale;
c1_1.x *= scores_scale; c1_1.y *= scores_scale; c1_1.z *= scores_scale; c1_1.w *= scores_scale;
c2_0.x *= scores_scale; c2_0.y *= scores_scale; c2_0.z *= scores_scale; c2_0.w *= scores_scale;
c2_1.x *= scores_scale; c2_1.y *= scores_scale; c2_1.z *= scores_scale; c2_1.w *= scores_scale;
c3_0.x *= scores_scale; c3_0.y *= scores_scale; c3_0.z *= scores_scale; c3_0.w *= scores_scale;
c3_1.x *= scores_scale; c3_1.y *= scores_scale; c3_1.z *= scores_scale; c3_1.w *= scores_scale;
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
template<bool Is_dropout=false, bool Split=false, bool is_tp1 = false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, Tensor1& sRow_sum_reduce_buffer, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
if constexpr (is_tp1)
{
flash::template warp_allreduce_tp1(row_sum, sRow_sum_reduce_buffer, sum_op);
}
else
{
flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
}
// if (block0())
// {
// printf("is_tp1 %d %d normalize_softmax_lse %.4f\n",is_tp1, threadIdx.x, row_sum(0));
// }
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
return lse;
};
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_prefill(Tensor0 &acc_s, Tensor1 &acc_o, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !true
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
// if (blockIdx.x == 0 && threadIdx.x == 0)
// {
// printf("threadIdx.x %.2f, scores_scale = %.4f\n",row_sum(mi), scores_scale );
// }
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
// if (blockIdx.x == 2)
// {
// printf("threadIdx.x %.2f \n",row_sum(mi) );
// }
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_prefill(Tensor0 &acc_o, Tensor1& sRow_sum_reduce_buffer, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __log2f(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
return lse;
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8(Tensor0 &acc_o, Tensor1& sRow_sum_reduce_buffer, float softmax_scale,float descale_v, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
return lse;
};
template<bool Is_first, bool Check_inf=false, bool is_tp1=false, typename Tensor0, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_fp8_tp1(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2,
v4f *acco_f32
) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if constexpr (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
if constexpr (is_tp1)
{
flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
}
else
{
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
if constexpr (is_tp1)
{
flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
}
else
{
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
} // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(1 == kNRows);
// #pragma unroll
// for (int mi = 0; mi < size(row_max); ++mi)
{
int mi = 0;
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
row_sum(mi) *= scores_scale;
for (int i = 0; i < 16; i++)
{
acco_f32[i].x *= scores_scale;
acco_f32[i].y *= scores_scale;
acco_f32[i].z *= scores_scale;
acco_f32[i].w *= scores_scale;
}
// #pragma unroll
// for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
// c0_0.x *= scores_scale; c0_0.y *= scores_scale; c0_0.z *= scores_scale; c0_0.w *= scores_scale;
// c0_1.x *= scores_scale; c0_1.y *= scores_scale; c0_1.z *= scores_scale; c0_1.w *= scores_scale;
// c1_0.x *= scores_scale; c1_0.y *= scores_scale; c1_0.z *= scores_scale; c1_0.w *= scores_scale;
// c1_1.x *= scores_scale; c1_1.y *= scores_scale; c1_1.z *= scores_scale; c1_1.w *= scores_scale;
// c2_0.x *= scores_scale; c2_0.y *= scores_scale; c2_0.z *= scores_scale; c2_0.w *= scores_scale;
// c2_1.x *= scores_scale; c2_1.y *= scores_scale; c2_1.z *= scores_scale; c2_1.w *= scores_scale;
// c3_0.x *= scores_scale; c3_0.y *= scores_scale; c3_0.z *= scores_scale; c3_0.w *= scores_scale;
// c3_1.x *= scores_scale; c3_1.y *= scores_scale; c3_1.z *= scores_scale; c3_1.w *= scores_scale;
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_fp8_tp4(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2,
v4f *acco_f32
) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if constexpr (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::template warp_allreduce_tp4(row_max, sRow_max_reduce_buffer, max_op);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
flash::template warp_allreduce_tp4(row_max, sRow_max_reduce_buffer, max_op);
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(1 == kNRows);
// #pragma unroll
// for (int mi = 0; mi < size(row_max); ++mi)
{
int mi = 0;
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
row_sum(mi) *= scores_scale;
for (int i = 0; i < 16; i++)
{
acco_f32[i].x *= scores_scale;
acco_f32[i].y *= scores_scale;
acco_f32[i].z *= scores_scale;
acco_f32[i].w *= scores_scale;
}
// #pragma unroll
// for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
// c0_0.x *= scores_scale; c0_0.y *= scores_scale; c0_0.z *= scores_scale; c0_0.w *= scores_scale;
// c0_1.x *= scores_scale; c0_1.y *= scores_scale; c0_1.z *= scores_scale; c0_1.w *= scores_scale;
// c1_0.x *= scores_scale; c1_0.y *= scores_scale; c1_0.z *= scores_scale; c1_0.w *= scores_scale;
// c1_1.x *= scores_scale; c1_1.y *= scores_scale; c1_1.z *= scores_scale; c1_1.w *= scores_scale;
// c2_0.x *= scores_scale; c2_0.y *= scores_scale; c2_0.z *= scores_scale; c2_0.w *= scores_scale;
// c2_1.x *= scores_scale; c2_1.y *= scores_scale; c2_1.z *= scores_scale; c2_1.w *= scores_scale;
// c3_0.x *= scores_scale; c3_0.y *= scores_scale; c3_0.z *= scores_scale; c3_0.w *= scores_scale;
// c3_1.x *= scores_scale; c3_1.y *= scores_scale; c3_1.z *= scores_scale; c3_1.w *= scores_scale;
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
template<bool Is_dropout=false, bool Split=false, bool is_tp1=false, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8_tp1(v4f *acco_f, Tensor1& sRow_sum_reduce_buffer, float softmax_scale,float descale_v, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
if constexpr (is_tp1)
{
flash::template warp_allreduce_tp1(row_sum, sRow_sum_reduce_buffer, sum_op);
}
else
{
flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
}
TensorT lse = make_fragment_like(row_sum);
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < 1; ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
// #pragma unroll
// for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
for (int i = 0; i < 16; i++)
{
acco_f[i].x *= scale;
acco_f[i].y *= scale;
acco_f[i].z *= scale;
acco_f[i].w *= scale;
}
}
return lse;
};
template<bool Is_dropout=false, bool Split=false, bool is_tp1=false, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8_tp4(v4f *acco_f, Tensor1& sRow_sum_reduce_buffer, float softmax_scale,float descale_v, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
flash::template warp_allreduce_tp4(row_sum, sRow_sum_reduce_buffer, sum_op);
TensorT lse = make_fragment_like(row_sum);
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < 1; ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
// #pragma unroll
// for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
for (int i = 0; i < 16; i++)
{
acco_f[i].x *= scale;
acco_f[i].y *= scale;
acco_f[i].z *= scale;
acco_f[i].w *= scale;
}
}
return lse;
};
};
} // namespace flash
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment