Unverified Commit cb6c5d39 authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

[CK_TILE] layernorm have more accurate residual (#1623)



* more accurate residual

* modify comment

* Fix literal case in README.md

---------
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
parent 03c6448b
...@@ -69,7 +69,7 @@ args: ...@@ -69,7 +69,7 @@ args:
``` ```
## limitations ## limitations
Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, N>8192 case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. Though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, `N>8192` case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. If need suport `N>8192` and `fused+residual+store`, you can use this example together with `12_smoothquant`, to construct layernorm+residual, and smoothquant, 2 kernels for this purpose.
``` ```
# some case # some case
...@@ -82,4 +82,4 @@ Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by d ...@@ -82,4 +82,4 @@ Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by d
# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant+fused-add-store, output in int8 # standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant+fused-add-store, output in int8
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1 ./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1
``` ```
\ No newline at end of file
...@@ -202,8 +202,9 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -202,8 +202,9 @@ float layernorm2d_fwd_(const S& s, A a)
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>; using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>; using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, YScaleDataType, YDataType, typename Traits_::Shape, static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, false, true/*max3*/>>; using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, XScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>; using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
......
...@@ -8,17 +8,23 @@ ...@@ -8,17 +8,23 @@
namespace ck_tile { namespace ck_tile {
template <bool kPadM_, bool kPadN_, bool UseRawStore_ = true, bool UseMax3_ = false> template <bool kPadM_,
bool kPadN_,
bool UseSmoothInputScale_,
bool UseRawStore_ = true,
bool UseMax3_ = false>
struct DynamicQuantEpilogueTraits struct DynamicQuantEpilogueTraits
{ {
static constexpr bool kPadM = kPadM_; static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_; static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
static constexpr bool UseMax3 = UseMax3_; static constexpr bool UseRawStore = UseRawStore_;
static constexpr bool UseMax3 = UseMax3_;
}; };
// this epilogue just store out a M*N matrix, row major // this epilogue just store out a M*N matrix, row major
template <typename AccDataType_, template <typename AccDataType_,
typename XScaleDataType_,
typename YScaleDataType_, typename YScaleDataType_,
typename ODataType_, typename ODataType_,
typename BlockShape_, typename BlockShape_,
...@@ -26,17 +32,20 @@ template <typename AccDataType_, ...@@ -26,17 +32,20 @@ template <typename AccDataType_,
struct DynamicQuantEpilogueProblem struct DynamicQuantEpilogueProblem
{ {
using AccDataType = remove_cvref_t<AccDataType_>; using AccDataType = remove_cvref_t<AccDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>; using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using ODataType = remove_cvref_t<ODataType_>; using ODataType = remove_cvref_t<ODataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
using Traits = remove_cvref_t<Traits_>; using Traits = remove_cvref_t<Traits_>;
}; };
// TODO: we should put descriptor creation function into policy
template <typename Problem_, typename Policy_ = void> template <typename Problem_, typename Policy_ = void>
struct DynamicQuantEpilogue struct DynamicQuantEpilogue
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using BlockShape = remove_cvref_t<typename Problem::BlockShape>; using BlockShape = remove_cvref_t<typename Problem::BlockShape>;
...@@ -63,6 +72,33 @@ struct DynamicQuantEpilogue ...@@ -63,6 +72,33 @@ struct DynamicQuantEpilogue
return BlockReduce2dCrossWarpSync<P_>{}; return BlockReduce2dCrossWarpSync<P_>{};
} }
CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution()
{
using S = BlockShape;
#if 0
// don't remove this
// Note that if we set encoding purposely like this, you will result in compile fail
// TODO: x_scale create local-scratch to accept arbitrary acc input (with same length)
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<0, 1, 1>,
sequence<0, 0, 3>>{});
#else
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
#endif
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync(); auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
...@@ -71,8 +107,12 @@ struct DynamicQuantEpilogue ...@@ -71,8 +107,12 @@ struct DynamicQuantEpilogue
// TODO: this function assume store out vector size is the same as OAccTile last dimension size // TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ? // how do we fix this ?
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile> template <typename ODramWindowTmp,
typename XScaleWindow,
typename YScaleWindow,
typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const XScaleWindow& x_scale_window_,
YScaleWindow& y_scale_window, YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile, const OAccTile& o_acc_tile,
void* smem) void* smem)
...@@ -80,6 +120,18 @@ struct DynamicQuantEpilogue ...@@ -80,6 +120,18 @@ struct DynamicQuantEpilogue
auto reduce = GetBlockReduce2d(); auto reduce = GetBlockReduce2d();
auto reduce_sync = GetBlockReduce2dSync(); auto reduce_sync = GetBlockReduce2dSync();
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync(); auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
const auto x_scale_window =
make_tile_window(x_scale_window_, MakeSmoothInputScaleTileDistribution());
auto x_scale = load_tile(x_scale_window);
auto o_acc_tmp = o_acc_tile;
sweep_tile(o_acc_tmp, [&](auto idx) {
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto xs_ = type_convert<AccDataType>(x_scale[j_idx]);
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
});
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); }; const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
...@@ -87,10 +139,9 @@ struct DynamicQuantEpilogue ...@@ -87,10 +139,9 @@ struct DynamicQuantEpilogue
constexpr auto y_size_per_row = constexpr auto y_size_per_row =
OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at( OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
number<1>{}); number<1>{});
// constexpr auto y_size_per_row = OAccTile::get_lengths()[number<1>{}];
if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0) if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
{ {
// fast max3 implementation // fast max3+abs implementation
const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) { const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
float rtn; float rtn;
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)" asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
...@@ -98,11 +149,11 @@ struct DynamicQuantEpilogue ...@@ -98,11 +149,11 @@ struct DynamicQuantEpilogue
: "v"(acc_), "v"(v_0_), "v"(v_1_)); : "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn; return rtn;
}; };
return reduce(o_acc_tile, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{}); return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
} }
else else
{ {
return reduce(o_acc_tile, type_convert<AccDataType>(0), f_absmax); return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
} }
}(); }();
reduce_sync(row_absmax, f_absmax); reduce_sync(row_absmax, f_absmax);
...@@ -117,23 +168,20 @@ struct DynamicQuantEpilogue ...@@ -117,23 +168,20 @@ struct DynamicQuantEpilogue
store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale)); store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));
auto o_acc_scaled_tile = sweep_tile(o_acc_tmp, [&](auto idx) {
make_static_distributed_tensor<AccDataType>(o_acc_tile.get_tile_distribution()); constexpr auto row_id = make_tuple(idx[number<0>{}]);
o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id);
sweep_tile(o_acc_tile, [&](auto idx) {
constexpr auto row_id = make_tuple(idx[number<0>{}]);
o_acc_scaled_tile(idx) = o_acc_tile[idx] / y_scale(row_id);
}); });
// TODO: this is ugly // TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN)) if constexpr(UseRawStore && (kPadM || kPadN))
{ {
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile)); store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
buffer_store_fence(); buffer_store_fence();
} }
else else
{ {
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile)); store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
} }
} }
}; };
......
...@@ -45,7 +45,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -45,7 +45,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford() CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
{ {
using P_ = BlockWelfordProblem<typename Problem::XDataType, using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
...@@ -55,7 +55,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -55,7 +55,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync() CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
{ {
using P_ = BlockWelfordProblem<typename Problem::XDataType, using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
...@@ -65,7 +65,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -65,7 +65,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync() CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
{ {
using P_ = BlockWelfordProblem<typename Problem::XDataType, using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
...@@ -77,13 +77,13 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -77,13 +77,13 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{ {
if constexpr(Problem::kNeedCrossWarpSync) if constexpr(Problem::kNeedCrossWarpSync)
{ {
using P_ = BlockWelfordProblem<typename Problem::XDataType, using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
using block_welford = BlockWelford<P_>; using block_welford = BlockWelford<P_>;
using x_block_tile = using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::XDataType>( decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>())); MakeXBlockTileDistribution<Problem>()));
using mean_var_block_tile = using mean_var_block_tile =
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>()); decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
......
...@@ -87,12 +87,9 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -87,12 +87,9 @@ struct Layernorm2dFwdPipelineOnePass
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>()); x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window( auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>()); y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto x_scale_window = make_tile_window(
x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
auto x_scale = load_tile(x_scale_window);
int cur_count = 0; int cur_count = 0;
int max_count = int max_count =
...@@ -106,21 +103,21 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -106,21 +103,21 @@ struct Layernorm2dFwdPipelineOnePass
const auto gamma = load_tile(gamma_window); const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window); const auto beta = load_tile(beta_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
sweep_tile(x_resi, [&](auto idx) { sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x // compute x = x_resi + x
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
}); });
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
store_tile(y_residual_window, x); store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
} }
// compute welford each-thread->cross-lane->cross-warp // compute welford each-thread->cross-lane->cross-warp
auto [mean, var] = block_welford(x, cur_count, max_count); auto [mean, var] = block_welford(acc, cur_count, max_count);
block_welford_sync(mean, var, cur_count); block_welford_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem); block_welford_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count); block_tile_welford_post_scale_var(var, cur_count);
...@@ -138,7 +135,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -138,7 +135,7 @@ struct Layernorm2dFwdPipelineOnePass
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std)); store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
// layernorm computation // layernorm computation
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution()); auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
sweep_tile(ln, [&, mean_ = mean](auto idx) { sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
...@@ -146,26 +143,15 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -146,26 +143,15 @@ struct Layernorm2dFwdPipelineOnePass
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]); const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]); const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]); auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
ln(idx) = ln_; ln(idx) = ln_;
}); });
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
// smooth-quant pre-scale, then run rowwise-quant
sweep_tile(ln, [&](auto idx) {
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto xs_ = type_convert<ComputeDataType>(x_scale[j_idx]);
ln(idx) = ln(idx) * xs_;
});
}
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT || if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{ {
Epilogue{}(y_window_, y_scale_window, ln, smem); Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem);
} }
else else
Epilogue{}(y_window_, ln); Epilogue{}(y_window_, ln);
......
...@@ -106,7 +106,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -106,7 +106,7 @@ struct Layernorm2dFwdPipelineTwoPass
auto block_welford_cross_warp_sync = auto block_welford_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>(); Policy::template GetBlockWelfordCrossWarpSync<Problem>();
using XTensorType = decltype(load_tile(x_window)); using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>(); auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>(); auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
...@@ -117,22 +117,22 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -117,22 +117,22 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N}); move_tile_window(x_residual_window, {0, Block_N});
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
sweep_tile(x_resi, [&](auto idx) { sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x // compute x = x_resi + x
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
}); });
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{ {
store_tile(y_residual_window, x); store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
move_tile_window(y_residual_window, {0, Block_N}); move_tile_window(y_residual_window, {0, Block_N});
} }
} }
block_welford(x, mean, var, cur_count, max_count); block_welford(acc, mean, var, cur_count, max_count);
} }
block_welford_sync(mean, var, cur_count); block_welford_sync(mean, var, cur_count);
...@@ -166,21 +166,21 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -166,21 +166,21 @@ struct Layernorm2dFwdPipelineTwoPass
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
sweep_tile(x_resi, [&](auto idx) { sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x // compute x = x_resi + x
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
}); });
} }
// load gamma/beta (TODO: support no gamma/beta?) // load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window); const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window); const auto beta = load_tile(beta_window);
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution()); auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
sweep_tile(ln, [&, mean_ = mean](auto idx) { sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
...@@ -189,8 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -189,8 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]); const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]); const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]); auto ln_ = (acc(idx) - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
ln(idx) = ln_; ln(idx) = ln_;
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment