Commit 54617a85 authored by Jiming Ruan

Adds support for the Welford algorithm and fast division (kFastFDiv) for rmsnorm

parent c2ea75ed
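
For context, the welford_update calls in the hunks below accumulate a running mean and variance in a single pass. A minimal sketch of the standard Welford recurrence follows; the names are illustrative rather than the exact ck_tile helpers, and it assumes the variance slot holds the running mean of squared deviations (M2 / count):

// Sketch only: single-pass Welford update for one new sample x.
template <typename T>
void welford_update_sketch(T& mean, T& var, T x, int& count)
{
    count += 1;
    T delta = x - mean;
    mean += delta / static_cast<T>(count);                  // running E[x]
    T delta2 = x - mean;
    var += (delta * delta2 - var) / static_cast<T>(count);  // running E[(x - mean)^2]
}

Compared with the plain sum / sum-of-squares path (the kWelford == false branches), this avoids the cancellation that can occur when forming E[x^2] - E[x]^2 in low precision.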
......@@ -36,7 +36,7 @@ struct BlockNormReduce
constexpr auto I1 = number<1>{};
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
constexpr bool computeVariance =
constexpr bool comp_var =
!std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
......@@ -50,7 +50,7 @@ struct BlockNormReduce
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
if(kWelford)
{
if constexpr(computeVariance)
if constexpr(comp_var)
{
welford_update(mean_tensor(out_dstr_idx),
var_tensor(out_dstr_idx),
......@@ -67,7 +67,7 @@ struct BlockNormReduce
else
{
mean_tensor(out_dstr_idx) += x;
if constexpr(computeVariance)
if constexpr(comp_var)
{
var_tensor(out_dstr_idx) += x * x;
}
......@@ -98,7 +98,8 @@ struct BlockNormReduce
int& cur_count_,
const int& max_count_)
{
Impl(x_tensor, mean_tensor, null_tensor{}, cur_count_, max_count_);
auto nt = null_tensor{};
Impl(x_tensor, mean_tensor, nt, cur_count_, max_count_);
}
template <typename XDistributedTensor_>
......@@ -152,31 +153,39 @@ struct BlockNormReduceSync
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
"wrong!");
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
constexpr bool computeVariance =
constexpr bool comp_var =
!std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
if constexpr(comp_var)
{
static_assert(
std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
"wrong!");
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
}
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
// const auto rs_idx =
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
static_assert((computeVariance == false) ||
(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
const int original_count = count;
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local_mean = mean_tensor.get_thread_buffer()[i];
auto v_local_var = computeVariance ? var_tensor.get_thread_buffer()[i] : 0;
auto v_local_mean = mean_tensor.get_thread_buffer()[i];
auto v_local_var = [&]() {
if constexpr(comp_var)
return var_tensor.get_thread_buffer()[i];
else
return 0;
}();
auto v_local_count = original_count;
// cross-lane reduce for replication
......@@ -206,13 +215,13 @@ struct BlockNormReduceSync
// pull data from remote lane
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
const auto v_remote_var =
computeVariance ? warp_shuffle(v_local_var, src_lane) : 0;
comp_var ? warp_shuffle(v_local_var, src_lane) : 0;
if(kWelford)
{
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
// norm_reduce merge
if constexpr(computeVariance)
if constexpr(comp_var)
{
welford_merge(v_local_mean,
v_local_var,
......@@ -234,7 +243,7 @@ struct BlockNormReduceSync
else
{
v_local_mean += v_remote_mean;
if constexpr(computeVariance)
if constexpr(comp_var)
{
v_local_var += v_remote_var;
}
......@@ -244,7 +253,11 @@ struct BlockNormReduceSync
});
mean_tensor.get_thread_buffer()(i) = v_local_mean;
var_tensor.get_thread_buffer()(i) = v_local_var;
if constexpr(comp_var)
{
var_tensor.get_thread_buffer()(i) = v_local_var;
}
if(kWelford)
{
count = v_local_count;
......@@ -263,7 +276,8 @@ struct BlockNormReduceSync
template <typename MeanDistributedTensor_>
CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count)
{
Impl(mean_tensor, null_tensor{}, count);
auto nt = null_tensor{};
Impl(mean_tensor, nt, count);
}
};
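
The cross-lane loop above pairs warp_shuffle with welford_merge to combine partial results from different lanes. A sketch of the standard pairwise combination (Chan et al.), again assuming the variance slot holds the running mean of squared deviations; the real welford_merge helper may organize its state differently:

// Sketch only: merge partial (mean_b, var_b, n_b) into (mean_a, var_a, n_a).
template <typename T>
void welford_merge_sketch(T& mean_a, T& var_a, int& n_a, T mean_b, T var_b, int n_b)
{
    const int n = n_a + n_b;
    if(n == 0)
        return;
    const T delta = mean_b - mean_a;
    mean_a += delta * static_cast<T>(n_b) / static_cast<T>(n);
    var_a = (var_a * static_cast<T>(n_a) + var_b * static_cast<T>(n_b) +
             delta * delta * static_cast<T>(n_a) * static_cast<T>(n_b) / static_cast<T>(n)) /
            static_cast<T>(n);
    n_a = n;
}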
......@@ -348,17 +362,18 @@ struct BlockNormReduceCrossWarpSync
// using DstrEncode = typename Dstr::DstrEncode;
// using DstrEncodeDetail = typename DstrEncode::detail;
constexpr bool computeVariance =
constexpr bool comp_var =
!std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
static_assert(
(computeVariance == false) ||
std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
"wrong!");
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
static_assert((computeVariance == false) ||
(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
if constexpr(comp_var)
{
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
static_assert(
std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
"wrong!");
}
// Note: we always pack everything into fp32x4
smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
......@@ -413,7 +428,12 @@ struct BlockNormReduceCrossWarpSync
// TODO: use descriptor for this
auto v_local = all_scratch[i_0 * num_reduce_warps];
auto v_local_mean = bit_cast<DataType>(v_local[0]);
auto v_local_var = kComputeVariance ? bit_cast<DataType>(v_local[1]) : 0;
auto v_local_var = [&]() {
if constexpr(comp_var)
return bit_cast<DataType>(v_local[1]);
else
return 0;
}();
int v_local_count = kWelford ? (kComputeVariance ? bit_cast<int>(v_local[2])
: bit_cast<int>(v_local[1]))
: 0;
......@@ -458,7 +478,11 @@ struct BlockNormReduceCrossWarpSync
});
mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
var_tensor.get_thread_buffer()(i_0) = v_local_var;
if constexpr(comp_var)
{
var_tensor.get_thread_buffer()(i_0) = v_local_var;
}
if constexpr(kWelford)
{
count = v_local_count;
......@@ -479,7 +503,8 @@ struct BlockNormReduceCrossWarpSync
template <typename MeanDistributedTensor_>
CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count, void* smem)
{
Impl(mean_tensor, null_tensor{}, count, smem);
auto nt = null_tensor{};
Impl(mean_tensor, nt, count, smem);
}
};
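
Note on the nt locals introduced in the operator() overloads above: null_tensor{} was previously passed as a temporary and is now materialized as a named variable first, presumably because Impl takes the variance tensor by non-const lvalue reference, which a prvalue cannot bind to. A minimal illustration of that C++ rule (types here are stand-ins, not the ck_tile ones):

struct fake_null_tensor {};
void impl_sketch(fake_null_tensor& /*var*/) {}

void caller_sketch()
{
    // impl_sketch(fake_null_tensor{}); // would not compile: rvalue cannot bind to T&
    fake_null_tensor nt{};
    impl_sketch(nt);                    // OK: named lvalue binds to the reference
}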
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
namespace ck_tile {
......@@ -43,30 +43,39 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2d<P_>{};
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
false,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockNormReduce<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dSync<P_>{};
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
false,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockNormReduceSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
false,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockNormReduceCrossWarpSync<P_>{};
}
template <typename Problem>
......@@ -74,17 +83,22 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
{
if constexpr(Problem::kNeedCrossWarpSync)
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
false,
Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
using block_reduce2d = BlockReduce2d<P_>;
using block_reduce = BlockNormReduce<P_>;
using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>()));
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
using mean_var_block_tile =
decltype(block_reduce::template MakeMeanVarBlockTile<x_block_tile>());
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
return GetBlockNormReduceCrossWarpSync<Problem>()
.template GetSmemSize<mean_var_block_tile>();
}
else
{
......
......@@ -31,6 +31,8 @@ struct Rmsnorm2dFwdPipelineOnePass
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
......@@ -62,7 +64,7 @@ struct Rmsnorm2dFwdPipelineOnePass
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window_,
YScaleWindow& y_scale_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem,
......@@ -77,12 +79,13 @@ struct Rmsnorm2dFwdPipelineOnePass
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
int cur_count = 0;
int max_count =
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
auto block_norm_reduce_cross_warp_sync =
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
......@@ -105,19 +108,32 @@ struct Rmsnorm2dFwdPipelineOnePass
}
}
// Calculate square here because block norm reduce only supports naive mean.
auto square = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
sweep_tile(acc, [&](auto idx) { square(idx) = acc(idx) * acc(idx); });
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d(acc,
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func);
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
using XTensorType = decltype(cast_tile<ComputeDataType>(x));
auto square_mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
clear_tile(square_mean);
block_norm_reduce(square, square_mean, cur_count, max_count);
block_norm_reduce_sync(square_mean, cur_count);
block_norm_reduce_cross_warp_sync(square_mean, cur_count, smem);
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
if constexpr(kFastFDiv && std::is_same_v<ComputeDataType, float>)
{
return type_convert<ComputeDataType>(1.0f) *
__builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
}
else
{
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
}
},
square_sum);
square_mean);
if constexpr(kSaveInvRms)
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
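
The kFastFDiv branch above swaps the exact divide for the AMDGCN reciprocal approximation; note also that the explicit division by row_size is gone because the norm reduce now produces a mean of squares rather than a raw sum. A minimal HIP-style sketch of the two inv-rms paths, assuming fp32 compute (the builtin is an approximate hardware reciprocal, so results can differ slightly from exact division):

#include <hip/hip_runtime.h>

// Sketch only: fast vs. exact inverse RMS for fp32.
__device__ float inv_rms_fast_sketch(float mean_sq, float epsilon)
{
    // Hardware reciprocal approximation exposed by the AMDGPU backend.
    return __builtin_amdgcn_rcpf(sqrtf(mean_sq + epsilon));
}

__device__ float inv_rms_exact_sketch(float mean_sq, float epsilon)
{
    return 1.0f / sqrtf(mean_sq + epsilon);
}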
......@@ -137,11 +153,11 @@ struct Rmsnorm2dFwdPipelineOnePass
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
Epilogue{}(y_window_, sm_scale_window_, y_scale_window, rmsn, smem);
}
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
Epilogue{}(y_window_, y_scale_window, rmsn, smem);
}
else
{
......
......@@ -31,6 +31,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
......@@ -82,16 +84,23 @@ struct Rmsnorm2dFwdPipelineTwoPass
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
// total count assumes the current iter has no padding (only the last iter is padded)
constexpr index_t count_per_iter =
Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;
using ComputeTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto square_sum = block_reduce2d.template MakeYBlockTile<ComputeTensorType>();
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
int cur_count = 0;
int max_count =
(num_n_tile_iteration - 1) * count_per_iter +
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
auto block_norm_reduce_cross_warp_sync =
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto square_mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
clear_tile(square_mean);
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
......@@ -102,6 +111,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
move_tile_window(x_residual_window, {0, Block_N});
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
......@@ -116,18 +126,29 @@ struct Rmsnorm2dFwdPipelineTwoPass
}
}
block_reduce2d(acc, square_sum, reduce_square_sum_func);
// Calculate square here because block norm reduce only supports naive mean.
sweep_tile(acc, [&](auto idx) { acc(idx) *= acc(idx); });
block_norm_reduce(acc, square_mean, cur_count, max_count);
}
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
block_norm_reduce_sync(square_mean, cur_count);
block_norm_reduce_cross_warp_sync(square_mean, cur_count, smem);
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
if constexpr(kFastFDiv && std::is_same_v<ComputeDataType, float>)
{
return type_convert<ComputeDataType>(1.0f) *
__builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
}
else
{
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
}
},
square_sum);
square_mean);
if constexpr(kSaveInvRms)
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
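
To make the max_count arithmetic in the hunk above concrete, here is a hypothetical example; the numbers are illustrative, not from the source. Assume Block_N = 256, Repeat_N = 2, Vector_N = 8 and row_size = 1000:

num_n_tile_iteration = ceil(1000 / 256) = 4
count_per_iter       = Repeat_N * Vector_N = 16     (samples per thread per full tile)
last_iter_n          = 1000 - 3 * 256 = 232         (columns handled by the last tile)
max_count            = 3 * 16 + block_tile_welford_calculate_max_count<BlockShape>(232)

Only the last, possibly padded tile goes through the per-shape max-count helper; the unpadded tiles each contribute the full per-thread count.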
......
......@@ -39,6 +39,8 @@ template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DY
template <bool kPadN_,
bool kSaveInvRms_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_,
Rmsnorm2dFusedAddEnum kFusedAdd_,
Rmsnorm2dFusedQuantEnum kFusedQuant_>
......@@ -46,6 +48,8 @@ struct Rmsnorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
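
With the two new flags in place, a kernel selects the Welford and fast-division paths through these traits. A hypothetical instantiation for illustration (the boolean values are arbitrary; the enum values are ones that appear elsewhere in this diff):

using TraitsSketch = Rmsnorm2dFwdTraits<true,   // kPadN_
                                        false,  // kSaveInvRms_
                                        true,   // kFastFDiv_ : take the __builtin_amdgcn_rcpf path
                                        true,   // kWelford_  : use Welford mean/var reduction
                                        false,  // kTwoPass_
                                        Rmsnorm2dFusedAddEnum::PRE_ADD,
                                        Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT>;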
......