Commit 54617a85 authored by Jiming Ruan's avatar Jiming Ruan

Adds support for the Welford algorithm and fast division to rmsnorm

parent c2ea75ed
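For readers unfamiliar with the algorithm named in the title: welford_update and welford_merge used in the diff below presumably follow the standard single-pass Welford recurrences. The scalar sketch here only illustrates those recurrences; the struct and function signatures are hypothetical and are not the ck_tile device implementations.

// Illustrative scalar sketch of the Welford recurrences assumed by this commit
// (hypothetical names/storage; not the ck_tile device code).
struct WelfordState
{
    float mean  = 0.f;
    float m2    = 0.f; // sum of squared deviations, i.e. variance * count
    int   count = 0;
};

// welford_update: fold one new sample into the running state.
inline void welford_update(WelfordState& s, float x)
{
    s.count += 1;
    const float delta = x - s.mean;
    s.mean += delta / s.count;
    s.m2 += delta * (x - s.mean);
}

// welford_merge: combine two partial states (e.g. from two lanes or warps).
inline WelfordState welford_merge(const WelfordState& a, const WelfordState& b)
{
    WelfordState r;
    r.count = a.count + b.count;
    if(r.count == 0)
        return r;
    const float delta = b.mean - a.mean;
    r.mean = a.mean + delta * b.count / r.count;
    r.m2   = a.m2 + b.m2 + delta * delta * a.count * b.count / r.count;
    return r;
}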
@@ -36,7 +36,7 @@ struct BlockNormReduce
         constexpr auto I1 = number<1>{};

         constexpr auto spans = XDistributedTensor_::get_distributed_spans();
-        constexpr bool computeVariance =
+        constexpr bool comp_var =
             !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;

         sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
@@ -50,7 +50,7 @@ struct BlockNormReduce
                 auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
                 if(kWelford)
                 {
-                    if constexpr(computeVariance)
+                    if constexpr(comp_var)
                     {
                         welford_update(mean_tensor(out_dstr_idx),
                                        var_tensor(out_dstr_idx),
@@ -67,7 +67,7 @@ struct BlockNormReduce
                 else
                 {
                     mean_tensor(out_dstr_idx) += x;
-                    if constexpr(computeVariance)
+                    if constexpr(comp_var)
                     {
                         var_tensor(out_dstr_idx) += x * x;
                     }
@@ -98,7 +98,8 @@ struct BlockNormReduce
                              int& cur_count_,
                              const int& max_count_)
    {
-        Impl(x_tensor, mean_tensor, null_tensor{}, cur_count_, max_count_);
+        auto nt = null_tensor{};
+        Impl(x_tensor, mean_tensor, nt, cur_count_, max_count_);
    }

    template <typename XDistributedTensor_>
@@ -152,31 +153,39 @@ struct BlockNormReduceSync
        using DstrEncode       = typename Dstr::DstrEncode;
        using DstrEncodeDetail = typename DstrEncode::detail;

-        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
-                      "wrong!");
-
        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

        constexpr index_t idim_p_lane = NDimP - 1;

-        constexpr bool computeVariance =
+        constexpr bool comp_var =
            !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
+        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
+
+        if constexpr(comp_var)
+        {
+            static_assert(
+                std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
+                "wrong!");
+            static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
+        }

        // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
        // const auto rs_idx =
        //    mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);

-        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
-        static_assert((computeVariance == false) ||
-                      (thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
-
        const int original_count = count;

        // loop over thread data
        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            auto v_local_mean = mean_tensor.get_thread_buffer()[i];
-            auto v_local_var  = computeVariance ? var_tensor.get_thread_buffer()[i] : 0;
+            auto v_local_var  = [&]() {
+                if constexpr(comp_var)
+                    return var_tensor.get_thread_buffer()[i];
+                else
+                    return 0;
+            }();
            auto v_local_count = original_count;

            // cross-lane reduce for replication
@@ -206,13 +215,13 @@ struct BlockNormReduceSync
                    // pull data from remote lane
                    const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
                    const auto v_remote_var =
-                        computeVariance ? warp_shuffle(v_local_var, src_lane) : 0;
+                        comp_var ? warp_shuffle(v_local_var, src_lane) : 0;
                    if(kWelford)
                    {
                        const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
                        // norm_reduce merge
-                        if constexpr(computeVariance)
+                        if constexpr(comp_var)
                        {
                            welford_merge(v_local_mean,
                                          v_local_var,
@@ -234,7 +243,7 @@ struct BlockNormReduceSync
                    else
                    {
                        v_local_mean += v_remote_mean;
-                        if constexpr(computeVariance)
+                        if constexpr(comp_var)
                        {
                            v_local_var += v_remote_var;
                        }
@@ -244,7 +253,11 @@ struct BlockNormReduceSync
            });

            mean_tensor.get_thread_buffer()(i) = v_local_mean;
-            var_tensor.get_thread_buffer()(i)  = v_local_var;
+            if constexpr(comp_var)
+            {
+                var_tensor.get_thread_buffer()(i) = v_local_var;
+            }
            if(kWelford)
            {
                count = v_local_count;
@@ -263,7 +276,8 @@ struct BlockNormReduceSync
    template <typename MeanDistributedTensor_>
    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count)
    {
-        Impl(mean_tensor, null_tensor{}, count);
+        auto nt = null_tensor{};
+        Impl(mean_tensor, nt, count);
    }
};
@@ -348,17 +362,18 @@ struct BlockNormReduceCrossWarpSync
        // using DstrEncode       = typename Dstr::DstrEncode;
        // using DstrEncodeDetail = typename DstrEncode::detail;

-        constexpr bool computeVariance =
+        constexpr bool comp_var =
            !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
+        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();

-        static_assert(
-            (computeVariance == false) ||
-                std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
-            "wrong!");
-
-        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
-        static_assert((computeVariance == false) ||
-                      (thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
+        if constexpr(comp_var)
+        {
+            static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
+            static_assert(
+                std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
+                "wrong!");
+        }

        // Note: we always pack everything into fp32x4
        smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
@@ -413,7 +428,12 @@ struct BlockNormReduceCrossWarpSync
            // TODO: use descriptor for this
            auto v_local      = all_scratch[i_0 * num_reduce_warps];
            auto v_local_mean = bit_cast<DataType>(v_local[0]);
-            auto v_local_var  = kComputeVariance ? bit_cast<DataType>(v_local[1]) : 0;
+            auto v_local_var  = [&]() {
+                if constexpr(comp_var)
+                    return bit_cast<DataType>(v_local[1]);
+                else
+                    return 0;
+            }();
            int v_local_count = kWelford ? (kComputeVariance ? bit_cast<int>(v_local[2])
                                                             : bit_cast<int>(v_local[1]))
                                         : 0;
@@ -458,7 +478,11 @@ struct BlockNormReduceCrossWarpSync
            });

            mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
-            var_tensor.get_thread_buffer()(i_0)  = v_local_var;
+            if constexpr(comp_var)
+            {
+                var_tensor.get_thread_buffer()(i_0) = v_local_var;
+            }
            if constexpr(kWelford)
            {
                count = v_local_count;
@@ -479,7 +503,8 @@ struct BlockNormReduceCrossWarpSync
    template <typename MeanDistributedTensor_>
    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count, void* smem)
    {
-        Impl(mean_tensor, null_tensor{}, count, smem);
+        auto nt = null_tensor{};
+        Impl(mean_tensor, nt, count, smem);
    }
};
......
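A side note on the `[&]() { if constexpr ... }()` initializers introduced above: unlike the old constexpr ternary, an immediately-invoked lambda with if constexpr never instantiates the discarded branch, so the variance tensor can legitimately be a null_tensor when variance is not requested. A minimal stand-alone sketch of the pattern (hypothetical types, not ck_tile code):

// Minimal sketch of the immediately-invoked "if constexpr" initializer pattern.
struct null_tensor {}; // placeholder type with no element access

template <bool HasVar, typename VarT>
float load_var(const VarT& var, int i)
{
    // The discarded branch is never instantiated, so VarT == null_tensor is fine.
    return [&]() {
        if constexpr(HasVar)
            return var[i];
        else
            return 0.f;
    }();
}

// load_var<false>(null_tensor{}, 0) compiles even though null_tensor has no operator[];
// a plain ternary would require both operands to be well-formed.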
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
-#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"

 namespace ck_tile {
@@ -43,30 +43,39 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
     {
-        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
-                                        typename Problem::ComputeDataType,
-                                        typename Problem::BlockShape>;
-        return BlockReduce2d<P_>{};
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          false,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
+        return BlockNormReduce<P_>{};
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
     {
-        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
-                                        typename Problem::ComputeDataType,
-                                        typename Problem::BlockShape>;
-        return BlockReduce2dSync<P_>{};
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          false,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
+        return BlockNormReduceSync<P_>{};
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
     {
-        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
-                                        typename Problem::ComputeDataType,
-                                        typename Problem::BlockShape>;
-        return BlockReduce2dCrossWarpSync<P_>{};
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          false,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
+        return BlockNormReduceCrossWarpSync<P_>{};
     }

     template <typename Problem>
@@ -74,17 +83,22 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
     {
         if constexpr(Problem::kNeedCrossWarpSync)
         {
-            using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
-                                            typename Problem::ComputeDataType,
-                                            typename Problem::BlockShape>;
+            using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                              typename Problem::ComputeDataType,
+                                              typename Problem::BlockShape,
+                                              false,
+                                              Problem::Traits::kFastFDiv,
+                                              Problem::Traits::kWelford>;

-            using block_reduce2d = BlockReduce2d<P_>;
+            using block_reduce = BlockNormReduce<P_>;
             using x_block_tile =
                 decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
                     MakeXBlockTileDistribution<Problem>()));
-            using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
+            using mean_var_block_tile =
+                decltype(block_reduce::template MakeMeanVarBlockTile<x_block_tile>());

-            return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
+            return GetBlockNormReduceCrossWarpSync<Problem>()
+                .template GetSmemSize<mean_var_block_tile>();
         }
         else
         {
......
@@ -31,6 +31,8 @@ struct Rmsnorm2dFwdPipelineOnePass
     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
     static constexpr bool kPadN = Problem::Traits::kPadN;
+    static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
+    static constexpr bool kWelford  = Problem::Traits::kWelford;
     static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
@@ -62,7 +64,7 @@ struct Rmsnorm2dFwdPipelineOnePass
                                const YResidualWindow& y_residual_window_,
                                InvRmsWindow& inv_rms_window,
                                const SmoothScaleWindow& sm_scale_window_,
-                               YScaleWindow& y_scale_window_,
+                               YScaleWindow& y_scale_window,
                                ComputeDataType epsilon,
                                ck_tile::index_t row_size,
                                void* smem,
@@ -77,12 +79,13 @@ struct Rmsnorm2dFwdPipelineOnePass
         auto y_residual_window = make_tile_window(
             y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());

-        auto reduce_square_sum_func = ReduceOp::SquareAdd{};
-        auto reduce_sum_func        = ReduceOp::Add{};
-        auto block_reduce2d         = Policy::template GetBlockReduce2d<Problem>();
-        auto block_reduce2d_sync    = Policy::template GetBlockReduce2dSync<Problem>();
-        auto block_reduce2d_cross_warp_sync =
-            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
+        int cur_count = 0;
+        int max_count =
+            block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
+        auto block_norm_reduce      = Policy::template GetBlockNormReduce<Problem>();
+        auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
+        auto block_norm_reduce_cross_warp_sync =
+            Policy::template GetBlockNormReduceCrossWarpSync<Problem>();

         auto x      = load_tile(x_window);
         auto x_resi = load_tile(x_residual_window);
@@ -105,19 +108,32 @@ struct Rmsnorm2dFwdPipelineOnePass
             }
         }

+        // Calculate square here because block norm reduce only supports naive mean.
+        auto square = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
+        sweep_tile(acc, [&](auto idx) { square(idx) = acc(idx) * acc(idx); });

         // compute mean square each-thread->cross-lane->cross-warp
-        auto square_sum = block_reduce2d(acc,
-                                         reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
-                                         reduce_square_sum_func);
-        block_reduce2d_sync(square_sum, reduce_sum_func);
-        block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
+        using XTensorType = decltype(cast_tile<ComputeDataType>(x));
+        auto square_mean  = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        clear_tile(square_mean);
+        block_norm_reduce(square, square_mean, cur_count, max_count);
+        block_norm_reduce_sync(square_mean, cur_count);
+        block_norm_reduce_cross_warp_sync(square_mean, cur_count, smem);

         // compute inv-rms
         auto inv_rms = tile_elementwise_in(
             [&](const auto& v_) {
-                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
+                if constexpr(kFastFDiv && std::is_same_v<ComputeDataType, float>)
+                {
+                    return type_convert<ComputeDataType>(1.0f) *
+                           __builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
+                }
+                else
+                {
+                    return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
+                }
             },
-            square_sum);
+            square_mean);

         if constexpr(kSaveInvRms)
             store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
@@ -137,11 +153,11 @@ struct Rmsnorm2dFwdPipelineOnePass
         if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
         {
-            Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
+            Epilogue{}(y_window_, sm_scale_window_, y_scale_window, rmsn, smem);
         }
         else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
         {
-            Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
+            Epilogue{}(y_window_, y_scale_window, rmsn, smem);
         }
         else
         {
......
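Note that the divide by row_size disappears from the inv-rms expression because the block reduction now produces a running mean of the squares (via the count bookkeeping) rather than a raw sum. The host-side sketch below shows the quantity the one-pass pipeline is computing; it is illustrative only, with the fast path approximated by a plain division since __builtin_amdgcn_rcpf is device-only, and gamma standing in for the per-column weight.

// Host-side reference sketch of the one-pass rmsnorm math (illustrative only).
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> rmsnorm_row(const std::vector<float>& x,
                               const std::vector<float>& gamma,
                               float epsilon)
{
    // mean of squares, matching the square_mean produced by the block reduction
    float square_mean = 0.f;
    for(std::size_t i = 0; i < x.size(); ++i)
        square_mean += x[i] * x[i];
    square_mean /= static_cast<float>(x.size());

    // inv_rms = 1 / sqrt(mean(x^2) + eps); the kernel's kFastFDiv path replaces
    // the division with a hardware reciprocal of the same quantity.
    const float inv_rms = 1.0f / std::sqrt(square_mean + epsilon);

    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = x[i] * inv_rms * gamma[i];
    return y;
}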
@@ -31,6 +31,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
     static constexpr bool kPadN = Problem::Traits::kPadN;
+    static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
+    static constexpr bool kWelford  = Problem::Traits::kWelford;
     static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
@@ -82,16 +84,23 @@ struct Rmsnorm2dFwdPipelineTwoPass
         index_t num_n_tile_iteration =
             __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));

-        auto reduce_square_sum_func = ReduceOp::SquareAdd{};
-        auto reduce_sum_func        = ReduceOp::Add{};
-        auto block_reduce2d         = Policy::template GetBlockReduce2d<Problem>();
-        auto block_reduce2d_sync    = Policy::template GetBlockReduce2dSync<Problem>();
-        auto block_reduce2d_cross_warp_sync =
-            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
-
-        using ComputeTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
-        auto square_sum = block_reduce2d.template MakeYBlockTile<ComputeTensorType>();
-        set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
+        // total number of count assume current iter have no pad(only last iter has pad)
+        constexpr index_t count_per_iter =
+            Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
+        const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;
+
+        int cur_count = 0;
+        int max_count =
+            (num_n_tile_iteration - 1) * count_per_iter +
+            block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
+        auto block_norm_reduce      = Policy::template GetBlockNormReduce<Problem>();
+        auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
+        auto block_norm_reduce_cross_warp_sync =
+            Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
+
+        using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
+        auto square_mean  = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        clear_tile(square_mean);

         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
@@ -102,6 +111,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
             move_tile_window(x_residual_window, {0, Block_N});

             auto acc = cast_tile<ComputeDataType>(x);
             if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
                          kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
             {
@@ -116,18 +126,29 @@ struct Rmsnorm2dFwdPipelineTwoPass
                 }
             }

-            block_reduce2d(acc, square_sum, reduce_square_sum_func);
+            // Calculate square here because block norm reduce only supports naive mean.
+            sweep_tile(acc, [&](auto idx) { acc(idx) *= acc(idx); });
+            block_norm_reduce(acc, square_mean, cur_count, max_count);
         }

-        block_reduce2d_sync(square_sum, reduce_sum_func);
-        block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
+        block_norm_reduce_sync(square_mean, cur_count);
+        block_norm_reduce_cross_warp_sync(square_mean, cur_count, smem);

         // compute inv-rms
         auto inv_rms = tile_elementwise_in(
             [&](const auto& v_) {
-                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
+                if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
+                {
+                    return type_convert<ComputeDataType>(1.0f) *
+                           __builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
+                }
+                else
+                {
+                    return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
+                }
             },
-            square_sum);
+            square_mean);

         if constexpr(kSaveInvRms)
             store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
......
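The two-pass max_count bookkeeping assumes every full Block_N tile contributes count_per_iter = Repeat_N * Vector_N elements per thread and only the last tile can be padded. A small arithmetic illustration with hypothetical shape values (these particular numbers are not from the commit):

// Hypothetical numbers only, to illustrate the max_count bookkeeping above.
// Assume Block_N = 256, Repeat_N = 4, Vector_N = 2, row_size = 1000.
//   num_n_tile_iteration = ceil(1000 / 256) = 4
//   count_per_iter       = Repeat_N * Vector_N = 8      // per-thread elements in a full tile
//   last_iter_n          = 1000 - 3 * 256 = 232
//   max_count            = 3 * 8 + block_tile_welford_calculate_max_count(232)
// i.e. each thread owns 8 elements in each of the 3 full tiles, plus whatever
// share of the remaining 232 columns the helper assigns to its lane.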
@@ -39,6 +39,8 @@ template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DY
 template <bool kPadN_,
           bool kSaveInvRms_,
+          bool kFastFDiv_,
+          bool kWelford_,
           bool kTwoPass_,
           Rmsnorm2dFusedAddEnum kFusedAdd_,
           Rmsnorm2dFusedQuantEnum kFusedQuant_>
@@ -46,6 +48,8 @@ struct Rmsnorm2dFwdTraits
 {
     static constexpr bool kPadN       = kPadN_;
     static constexpr bool kSaveInvRms = kSaveInvRms_;
+    static constexpr bool kFastFDiv   = kFastFDiv_;
+    static constexpr bool kWelford    = kWelford_;
     static constexpr bool kTwoPass    = kTwoPass_;
     static constexpr Rmsnorm2dFusedAddEnum kFusedAdd     = kFusedAdd_;
     static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
......
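For callers, the two new trait flags simply extend the existing parameter list. A hypothetical instantiation showing where kFastFDiv_ and kWelford_ sit (the surrounding argument values are illustrative, not taken from this commit):

// Hypothetical trait instantiation; only the position of kFastFDiv_/kWelford_ matters here.
using ExampleTraits =
    ck_tile::Rmsnorm2dFwdTraits<true,   // kPadN_
                                false,  // kSaveInvRms_
                                true,   // kFastFDiv_  (use the rcp-based fast division path)
                                true,   // kWelford_   (use Welford mean/variance updates)
                                false,  // kTwoPass_
                                ck_tile::Rmsnorm2dFusedAddEnum::PRE_ADD,
                                ck_tile::Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT>;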