// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
//
// Online-softmax helpers for the attention kernels. This port differs from upstream in two
// visible ways: exponentials go through the AMD GPU builtin __builtin_amdgcn_exp2f (the
// portable exp2f paths are kept under `#if 0`), and cross-thread reductions treat each run
// of 64 consecutive threads (threadIdx.x / 64) as one "warp", i.e. a 64-wide wavefront.
//
// NOTE(review): in this copy of the file every angle-bracketed span appears to have been
// stripped. That includes the system header names on the #include lines below, every
// template parameter list (`template __device__ ...`, `template struct Softmax`), and
// explicit template arguments at call sites (`flash::template reduce_max(...)`,
// `MaxOp max_op`, `make_tensor(Shape>{})`). Names such as zero_init, Layout0, Is_first,
// Check_inf, is_tp1, Split, Is_dropout, Scale_max and kNRows are used but never declared.
// Restore them from version control before building; the exact parameter order matters
// to callers that pass explicit template arguments. TODO confirm.
#pragma once
// NOTE(review): the header names of the next three includes are missing (presumably
// <cmath>, <cute/tensor.hpp> and <cutlass/numeric_types.h> as in upstream) — confirm.
#include
#include
#include
#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-thread row reduction. For each row mi of the 2D register tensor `tensor`, folds every
// column into summary(mi) with `op`.
//   zero_init == true : summary(mi) is overwritten (seeded with column 0).
//   zero_init == false: the previous summary(mi) is folded in as well.
// No cross-thread communication happens here.
template __device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); mi++) {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}

// Cross-lane all-reduce: dst(i) = Allreduce<64>::run(src(i), op) for every element.
// The "quad" name is inherited from upstream; the lane count passed here is 64.
// Allreduce is defined in utils.h (not shown). Presumably it is a shuffle-based reduction
// over the lanes that share a row — confirm which lanes participate before relying on it.
template __device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int i = 0; i < size(dst); i++){
        dst(i) = Allreduce<64>::run(src(i), op);
    }
}

// Block-level all-reduce of dst(0) across four 64-thread groups, through `smem_reduce`.
// Lanes 0..15 of each group (warp_id = tidx / 64) publish the value for row tidx % 16 at
// smem_reduce(row * 4 + warp_id). After the barrier, every thread combines the four
// partials of its row.
// Assumptions visible in the indexing:
//   * exactly 4 groups (256 threads). warp_id >= 4 would alias the next row's slots, and
//     fewer than 4 groups would read slots not written by this call;
//   * smem_reduce holds at least 64 elements (max index 15 * 4 + 3);
//   * only dst(0) is reduced (see the commented-out static_assert).
// Contains __syncthreads(), so every thread of the block must reach this call.
// NOTE(review): no barrier follows the final read. If the same buffer is written again
// (e.g. by the next call) before all threads have read it, a fast group could overwrite a
// slot a slow group still needs. Presumably callers separate reuses with another barrier —
// confirm.
template __device__ __forceinline__ void warp_allreduce_(Tensor &dst, Tensor &smem_reduce, Operator &op) {
    const int tidx = threadIdx.x;
    const int row = tidx % 16;
    const int col = tidx / 64;  // NOTE(review): unused; only the commented-out 2D store below refers to it.
    const int warp_id = tidx / 64;
    // static_assert(size(dst) == 1);
    // (Translated from Chinese.) Two layouts are possible here. One uses contiguous writes
    // with strided reads (the commented-out `row + warp_id * 16` store). The other uses
    // strided writes with contiguous reads (the active `row * 4 + warp_id` store). How to
    // trade them off? The performance impact is small.
    if ((tidx % 64) / 16 == 0)
    // if (tidx >= warp_id * 64 && tidx <= warp_id * 64 + 16)
    {
        // smem_reduce(row + warp_id * 16) = dst(0);
        smem_reduce(row * 4 + warp_id * 1) = dst(0);
        // smem_reduce(row, col) = dst(0);
    }
    __syncthreads();
    // Earlier variant: threads tidx < 16 combined the four partials into smem_reduce(row + 64),
    // then a second __syncthreads() and a broadcast read. Now every thread combines its row's
    // partials directly, which saves one barrier.
    dst(0) = op(op(smem_reduce(row * 4), smem_reduce(row * 4 + 1)), op(smem_reduce(row * 4 + 2), smem_reduce(row * 4 + 3)));
}

// Variant of the block all-reduce for eight 64-thread groups (512 threads) covering 64 rows.
// row = tidx % 16 + (warp_id % 4) * 16, so groups w and w + 4 hold partials for the same
// rows. Lanes 0..15 of each group (col == 0) store dst[0] at
// smem_reduce[row * 2 + warp_id / 4]. After the barrier, every thread combines its row's
// two partials.
// Requirements:
//   * smem_reduce holds at least 128 elements (max index 63 * 2 + 1);
//   * every thread of the block calls it (contains __syncthreads()).
// NOTE(review): as in the 4-group version, no barrier follows the final read, so buffer
// reuse must be ordered by the caller — confirm.
template __device__ __forceinline__ void warp_allreduce_tp1(Tensor &dst, Tensor &smem_reduce, Operator &op) {
    const int tidx = threadIdx.x;
    const int col = (tidx % 64) / 16;
    const int warp_id = tidx / 64;
    const int row = tidx % 16 + (warp_id % 4) * 16;
    // Groups sharing rows: 0-4 1-5 2-6 3-7
    if (col == 0) {
        smem_reduce[row * 2 + (warp_id / 4)] = dst[0];
    }
    __syncthreads();
    // Earlier variant: groups 0..3 combined each pair into smem_reduce[128 + row], followed by
    // a second __syncthreads() and a broadcast read of smem_reduce(128 + row). Now every
    // thread combines the pair itself.
    dst(0) = op(smem_reduce[row * 2], smem_reduce[row * 2 + 1]);
}

// Variant of the block all-reduce for four 64-thread groups (256 threads) covering 32 rows.
// row = tidx % 16 + (warp_id % 2) * 16, so groups 0/2 and 1/3 share rows. It runs in two
// phases:
//   1) lanes 0..15 of every group store dst[0] at smem_reduce[row * 2 + warp_id / 2];
//   2) lanes 0..15 of groups 0 and 1 combine each pair into smem_reduce[row + 64].
// Every thread then reads its row's result.
// Requirements:
//   * smem_reduce holds at least 96 elements (max index 31 + 64);
//   * every thread of the block calls it (two __syncthreads()).
// NOTE(review): no barrier follows the final read; buffer reuse must be ordered by the
// caller — confirm.
template __device__ __forceinline__ void warp_allreduce_tp4(Tensor &dst, Tensor &smem_reduce, Operator &op) {
    const int tidx = threadIdx.x;
    const int col = (tidx % 64) / 16;
    const int warp_id = tidx / 64;
    const int row = tidx % 16 + (warp_id % 2) * 16;
    // Groups sharing rows: 0-2 1-3
    if (col == 0) {
        smem_reduce[row * 2 + (warp_id / 2)] = dst[0];
    }
    __syncthreads();
    if (col == 0 && warp_id < 2) {
        smem_reduce[row + 64] = op(smem_reduce[row * 2], smem_reduce[row * 2 + 1]);
    }
    __syncthreads();
    dst(0) = smem_reduce(row + 64);
}

// Row reduction: thread-local fold (thread_reduce_) followed by the cross-lane
// quad_allreduce_. Does not combine across 64-thread groups; the warp_allreduce_* helpers
// do that.
template __device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) {
    thread_reduce_(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}

// Row-wise maximum of `tensor` into `max`, reduced within the thread and across lanes.
template __device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){
    MaxOp max_op;
    reduce_(tensor, max, max_op);
}

// Row-wise sum of `tensor` into `sum`, thread-local only. The cross-lane and cross-group
// parts are deferred to the normalize_softmax_lse* methods of Softmax.
template __device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){
    SumOp sum_op;
    thread_reduce_(tensor, sum, sum_op);
}

// Apply the exp to all the elements, in place:
//   tensor(mi, ni) = exp2(tensor(mi, ni) * scale - max_scaled(mi))
// where max_scaled(mi) is:
//   * 0 when max(mi) is -inf;
//   * otherwise max(mi) * (Scale_max ? scale : log2(e)).
// Presumably callers pass scale = softmax_scale * log2(e) (named softmax_scale_log2 at the
// call sites). With Scale_max, the result is then exp(softmax_scale * (x - max)).
template __forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)). This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            // The following macro will disable the use of fma.
            // See: https://github.com/pytorch/pytorch/issues/121558 for more details
            // This macro is set in PyTorch and not FlashAttention.
            // In this port the exp2f paths are disabled and the AMD builtin is used instead.
            #if 0
            #ifdef UNFUSE_FMA
            tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
            #else
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            #endif
            #else
            tensor(mi, ni) = __builtin_amdgcn_exp2f(tensor(mi, ni) * scale - max_scaled);
            #endif
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Online-softmax state for the rows owned by one thread.
//   row_max: running maximum of the raw (unscaled) scores of each row. It is multiplied by
//            the scale only inside scale_apply_exp2 and the lse computation.
//   row_sum: running sum of the exponentiated scores of each row. Until a
//            normalize_softmax_lse* call it holds only this thread's partial sum.
// kNRows is the number of rows per thread (its template parameter list is stripped in this
// copy). TensorT is presumably a register tensor of kNRows floats — confirm once the
// template arguments of make_tensor are restored.
// Methods that perform a cross-group reduction contain __syncthreads() and must be called
// by every thread of the block.
template struct Softmax {

    using TensorT = decltype(make_tensor(Shape>{}));
    TensorT row_max, row_sum;

    __forceinline__ __device__ Softmax() {};

    // One online-softmax step for a new tile of scores.
    //   acc_s: score accumulator, viewed as (rows, cols) via convert_layout_acc_rowcol. It is
    //          exponentiated in place by scale_apply_exp2 (unnormalized probabilities).
    //   acc_o: output accumulator, same (rows, cols) view. When Is_first is false it is
    //          rescaled by exp2((old_max - new_max) * softmax_scale_log2).
    //   sRow_max_reduce_buffer: scratch for the cross-group max, passed to warp_allreduce_tp1
    //          when is_tp1 and to warp_allreduce_ otherwise (presumably shared memory, per
    //          the `s` prefix). It must satisfy that helper's size requirement.
    //   softmax_scale_log2: multiplier applied to scores inside exp2.
    // Is_first : initialize row_max/row_sum from this tile instead of rescaling.
    // Check_inf: use 0 in place of a -inf row max when computing the rescale factor, so a
    //            fully masked row does not produce (-inf) - (-inf) = NaN.
    template __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2) {
        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        MaxOp max_op;
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        if (Is_first) {
            flash::template reduce_max(scores, row_max);
            if constexpr (is_tp1) {
                flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
            } else {
                flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            flash::reduce_sum(scores, row_sum);
        } else {
            // Keep the previous max so the old accumulators can be rescaled to the new one.
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max(scores, row_max);
            if constexpr (is_tp1) {
                flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
            } else {
                flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
            }
            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
            Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
            static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
            #pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi) {
                float scores_max_cur = !Check_inf ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                #if 0
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #else
                float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #endif
                row_sum(mi) *= scores_scale;
                #pragma unroll
                for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            flash::reduce_sum(scores, row_sum);
        }
    };

    // Same step as softmax_rescale_o, with two differences:
    //   * the output accumulator is passed as eight v4f registers (c0_0 ... c3_1, 32 floats)
    //     instead of a tensor, and each one is multiplied by the rescale factor;
    //   * the cross-group max always uses warp_allreduce_ (no is_tp1 switch).
    // Requires kNRows == 1 (static_assert below).
    template __forceinline__ __device__ void softmax_rescale_o_fp8(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2, v4f& c0_0, v4f& c0_1, v4f& c1_0, v4f& c1_1, v4f& c2_0, v4f& c2_1, v4f& c3_0, v4f& c3_1 ) {
        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        MaxOp max_op;
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        if (Is_first) {
            flash::template reduce_max(scores, row_max);
            flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            flash::reduce_sum(scores, row_sum);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max(scores, row_max);
            flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
            // The accumulator arrives as raw v4f registers, so there is no acc_o tensor to reshape.
            static_assert(1 == kNRows);
            // With a single row, the per-row loop is reduced to a plain block with mi = 0.
            {
                int mi = 0;
                float scores_max_cur = !Check_inf ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                #if 0
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #else
                float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #endif
                row_sum(mi) *= scores_scale;
                // Rescale every component of the eight accumulator registers.
                c0_0.x *= scores_scale; c0_0.y *= scores_scale; c0_0.z *= scores_scale; c0_0.w *= scores_scale;
                c0_1.x *= scores_scale; c0_1.y *= scores_scale; c0_1.z *= scores_scale; c0_1.w *= scores_scale;
                c1_0.x *= scores_scale; c1_0.y *= scores_scale; c1_0.z *= scores_scale; c1_0.w *= scores_scale;
                c1_1.x *= scores_scale; c1_1.y *= scores_scale; c1_1.z *= scores_scale; c1_1.w *= scores_scale;
                c2_0.x *= scores_scale; c2_0.y *= scores_scale; c2_0.z *= scores_scale; c2_0.w *= scores_scale;
                c2_1.x *= scores_scale; c2_1.y *= scores_scale; c2_1.z *= scores_scale; c2_1.w *= scores_scale;
                c3_0.x *= scores_scale; c3_0.y *= scores_scale; c3_0.z *= scores_scale; c3_0.w *= scores_scale;
                c3_1.x *= scores_scale; c3_1.y *= scores_scale; c3_1.z *= scores_scale; c3_1.w *= scores_scale;
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            flash::reduce_sum(scores, row_sum);
        }
    };

    // Finishes the softmax. It completes the row_sum reduction across lanes (quad_allreduce_)
    // and across groups (warp_allreduce_tp1 when is_tp1, else warp_allreduce_), then divides
    // acc_o by the row sum. If Is_dropout, the multiplier is additionally scaled by rp_dropout.
    // Returns the per-row log-sum-exp: row_max * softmax_scale + log(row_sum).
    // A row whose sum is 0 or NaN (sum != sum) gets lse = -inf when Split, +inf otherwise;
    // its acc_o is left unscaled (apart from rp_dropout).
    template __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, Tensor1& sRow_sum_reduce_buffer, float softmax_scale, float rp_dropout=1.0) {
        SumOp sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        if constexpr (is_tp1) {
            flash::template warp_allreduce_tp1(row_sum, sRow_sum_reduce_buffer, sum_op);
        } else {
            flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
        }
        TensorT lse = make_fragment_like(row_sum);
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
        #pragma unroll
        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
        }
        return lse;
    };

    // Prefill version of softmax_rescale_o, with two differences:
    //   * the cross-group max always uses warp_allreduce_;
    //   * the -inf check is always on (`!true ?`, independent of Check_inf).
    template __forceinline__ __device__ void softmax_rescale_o_prefill(Tensor0 &acc_s, Tensor1 &acc_o, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2) {
        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        MaxOp max_op;
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        if (Is_first) {
            flash::template reduce_max(scores, row_max);
            flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            flash::reduce_sum(scores, row_sum);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max(scores, row_max);
            flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
            Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
            static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
            #pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi) {
                float scores_max_cur = !true ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                #if 0
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #else
                float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #endif
                row_sum(mi) *= scores_scale;
                #pragma unroll
                for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            flash::reduce_sum(scores, row_sum);
        }
    };

    // Prefill version of normalize_softmax_lse: the cross-group sum always uses
    // warp_allreduce_. Same lse and scaling rules.
    template __forceinline__ __device__ TensorT normalize_softmax_lse_prefill(Tensor0 &acc_o, Tensor1& sRow_sum_reduce_buffer, float softmax_scale, float rp_dropout=1.0) {
        SumOp sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
        TensorT lse = make_fragment_like(row_sum);
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
        #pragma unroll
        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
        }
        return lse;
    };

    // FP8 version of normalize_softmax_lse_prefill. The output multiplier is
    // descale_v / row_sum instead of 1 / row_sum, which folds descale_v (presumably the
    // dequantization scale of V) into the normalization. The lse is unaffected by descale_v.
    template __forceinline__ __device__ TensorT normalize_softmax_lse_fp8(Tensor0 &acc_o, Tensor1& sRow_sum_reduce_buffer, float softmax_scale,float descale_v, float rp_dropout=1.0) {
        SumOp sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
        TensorT lse = make_fragment_like(row_sum);
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
        #pragma unroll
        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
        }
        return lse;
    };

    // Online-softmax step whose output accumulator is an array of v4f registers.
    // acco_f32 must point to at least 16 v4f (64 floats); all of them are rescaled.
    // The cross-group max uses warp_allreduce_tp1 when is_tp1, else warp_allreduce_.
    // Requires kNRows == 1 (static_assert below).
    template __forceinline__ __device__ void softmax_rescale_o_fp8_tp1(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2, v4f *acco_f32 ) {
        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        MaxOp max_op;
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        if constexpr (Is_first) {
            flash::template reduce_max(scores, row_max);
            if constexpr (is_tp1) {
                flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
            } else {
                flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            flash::reduce_sum(scores, row_sum);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max(scores, row_max);
            if constexpr (is_tp1) {
                flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
            } else {
                flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
            }
            // The accumulator arrives as a raw v4f array, so there is no acc_o tensor to reshape.
            static_assert(1 == kNRows);
            // With a single row, the per-row loop is reduced to a plain block with mi = 0.
            {
                int mi = 0;
                float scores_max_cur = !Check_inf ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                #if 0
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #else
                float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #endif
                row_sum(mi) *= scores_scale;
                // Rescale all 16 accumulator vectors (replaces the earlier c0_0 ... c3_1 form).
                for (int i = 0; i < 16; i++) {
                    acco_f32[i].x *= scores_scale;
                    acco_f32[i].y *= scores_scale;
                    acco_f32[i].z *= scores_scale;
                    acco_f32[i].w *= scores_scale;
                }
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            flash::reduce_sum(scores, row_sum);
        }
    };

    // Same step with a v4f-array accumulator, using the 4-group/32-row warp_allreduce_tp4
    // for the cross-group max. acco_f32 must point to at least 16 v4f. Requires kNRows == 1.
    template __forceinline__ __device__ void softmax_rescale_o_fp8_tp4(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2, v4f *acco_f32 ) {
        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        MaxOp max_op;
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        if constexpr (Is_first) {
            flash::template reduce_max(scores, row_max);
            flash::template warp_allreduce_tp4(row_max, sRow_max_reduce_buffer, max_op);
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            flash::reduce_sum(scores, row_sum);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max(scores, row_max);
            flash::template warp_allreduce_tp4(row_max, sRow_max_reduce_buffer, max_op);
            // The accumulator arrives as a raw v4f array, so there is no acc_o tensor to reshape.
            static_assert(1 == kNRows);
            // With a single row, the per-row loop is reduced to a plain block with mi = 0.
            {
                int mi = 0;
                float scores_max_cur = !Check_inf ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                #if 0
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #else
                float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #endif
                row_sum(mi) *= scores_scale;
                // Rescale all 16 accumulator vectors.
                for (int i = 0; i < 16; i++) {
                    acco_f32[i].x *= scores_scale;
                    acco_f32[i].y *= scores_scale;
                    acco_f32[i].z *= scores_scale;
                    acco_f32[i].w *= scores_scale;
                }
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            flash::reduce_sum(scores, row_sum);
        }
    };

    // Normalization for the v4f-array accumulator. The cross-group sum uses
    // warp_allreduce_tp1 when is_tp1, else warp_allreduce_. The output multiplier is
    // descale_v / row_sum (times rp_dropout if Is_dropout), applied to 16 v4f at acco_f.
    // NOTE(review): only row mi = 0 is processed, so this assumes kNRows == 1, but unlike the
    // rescale methods there is no static_assert. With kNRows > 1, lse(1..) would be left
    // unset.
    template __forceinline__ __device__ TensorT normalize_softmax_lse_fp8_tp1(v4f *acco_f, Tensor1& sRow_sum_reduce_buffer, float softmax_scale,float descale_v, float rp_dropout=1.0) {
        SumOp sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        if constexpr (is_tp1) {
            flash::template warp_allreduce_tp1(row_sum, sRow_sum_reduce_buffer, sum_op);
        } else {
            flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
        }
        TensorT lse = make_fragment_like(row_sum);
        #pragma unroll
        for (int mi = 0; mi < 1; ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
            for (int i = 0; i < 16; i++) {
                acco_f[i].x *= scale;
                acco_f[i].y *= scale;
                acco_f[i].z *= scale;
                acco_f[i].w *= scale;
            }
        }
        return lse;
    };

    // Same normalization as normalize_softmax_lse_fp8_tp1, but the cross-group sum always
    // uses warp_allreduce_tp4.
    // NOTE(review): like the tp1 version, this processes only mi = 0 and assumes kNRows == 1
    // without a static_assert.
    template __forceinline__ __device__ TensorT normalize_softmax_lse_fp8_tp4(v4f *acco_f, Tensor1& sRow_sum_reduce_buffer, float softmax_scale,float descale_v, float rp_dropout=1.0) {
        SumOp sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        flash::template warp_allreduce_tp4(row_sum, sRow_sum_reduce_buffer, sum_op);
        TensorT lse = make_fragment_like(row_sum);
        #pragma unroll
        for (int mi = 0; mi < 1; ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
            for (int i = 0; i < 16; i++) {
                acco_f[i].x *= scale;
                acco_f[i].y *= scale;
                acco_f[i].z *= scale;
                acco_f[i].w *= scale;
            }
        }
        return lse;
    };

    // Online-softmax step with no cross-group exchange and no shared memory. The only
    // reduction is reduce_max's cross-lane quad_allreduce_, so presumably each row is held
    // entirely within one 64-thread group — confirm.
    // `scores` is taken already in (rows, cols) form (the reshape is not done here).
    // acc_o must point to at least 32 v4f (128 floats).
    // NOTE(review): the rescale loop iterates over all rows but multiplies all 32 acc_o
    // vectors by each row's factor, which is correct only when kNRows == 1 (the companion
    // normalize_softmax_lse_prefill_4x1 handles only mi = 0). Consider
    // static_assert(1 == kNRows). `max_op` below is unused in this method.
    template __forceinline__ __device__ void softmax_rescale_o_prefill_4x1(Tensor0& scores, v4f* acc_o, float softmax_scale_log2) {
        MaxOp max_op;
        static_assert(decltype(size<0>(scores))::value == kNRows);
        if constexpr(Is_first) {
            flash::template reduce_max(scores, row_max);
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            flash::reduce_sum(scores, row_sum);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max(scores, row_max);
            #pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi) {
                float scores_max_cur = !Check_inf ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                #if 0
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #else
                float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                #endif
                row_sum(mi) *= scores_scale;
                for (int i = 0; i < 32; i++) {
                    acc_o[i].x *= scores_scale;
                    acc_o[i].y *= scores_scale;
                    acc_o[i].z *= scores_scale;
                    acc_o[i].w *= scores_scale;
                }
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            flash::reduce_sum(scores, row_sum);
        }
    };

    // Normalization companion of softmax_rescale_o_prefill_4x1. It completes row_sum with
    // quad_allreduce_ only (no cross-group step and no shared-memory buffer). It then scales
    // 32 v4f at acc_o by 1 / row_sum (times rp_dropout if Is_dropout) and returns the lse
    // using the same 0/NaN rules as normalize_softmax_lse.
    // NOTE(review): only mi = 0 is processed (assumes kNRows == 1; not asserted).
    template __forceinline__ __device__ TensorT normalize_softmax_lse_prefill_4x1(v4f *acc_o, float softmax_scale, float rp_dropout=1.0) {
        SumOp sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        TensorT lse = make_fragment_like(row_sum);
        #pragma unroll
        for (int mi = 0; mi < 1; ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
            for (int i = 0; i < 32; i++) {
                acc_o[i].x *= scale;
                acc_o[i].y *= scale;
                acc_o[i].z *= scale;
                acc_o[i].w *= scale;
            }
        }
        return lse;
    };
};

}  // namespace flash