"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "c778f4ad182ae1a9b86d8bbaa09d9c6267d66374"
Unverified Commit d5c8a334 authored by AMD-dteng's avatar AMD-dteng Committed by GitHub
Browse files

enable bias feature that add bias before adding residual (for rtpllm project) (#1741)



* 1. enable bias feature that add bias before adding residual; 2. change block size from 128->64 when m<64 in fp16

* delete comment

* 1.remove fmha change 2.change buffer name from bias to xbias

* Now bias can be used independently from fadd

* change kbias to kxbias

---------
Co-authored-by: default avatarfeli <felix.li@amd.com>
parent a6b761c3
This diff is collapsed.
...@@ -41,6 +41,7 @@ auto create_args(int argc, char* argv[]) ...@@ -41,6 +41,7 @@ auto create_args(int argc, char* argv[])
.insert("prec_sy", .insert("prec_sy",
"auto", "auto",
"output quant scale type, set auto will use fp32. used when fquant=1 or 2") "output quant scale type, set auto will use fp32. used when fquant=1 or 2")
.insert("xbias", "0", "add bias, 0:no add, 1:add bias before fadd")
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("warmup", "5", "cold iter") .insert("warmup", "5", "cold iter")
...@@ -93,6 +94,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -93,6 +94,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
int xbias = arg_parser.get_int("xbias");
int fused_add = arg_parser.get_int("fadd"); int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant"); int fused_quant = arg_parser.get_int("fquant");
if(fused_quant == 1 && prec_o != "int8") if(fused_quant == 1 && prec_o != "int8")
...@@ -107,6 +109,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -107,6 +109,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType; using YDataType = typename TypeConfig::YDataType;
using XBiasDataType = typename TypeConfig::XBiasDataType;
using GammaDataType = typename TypeConfig::GammaDataType; using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType; using BetaDataType = typename TypeConfig::BetaDataType;
using XResidualDataType = XDataType; using XResidualDataType = XDataType;
...@@ -121,6 +124,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -121,6 +124,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XBiasDataType> x_bias_host({n});
ck_tile::HostTensor<GammaDataType> gamma_host({n}); ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<BetaDataType> beta_host({n}); ck_tile::HostTensor<BetaDataType> beta_host({n});
...@@ -141,10 +145,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -141,10 +145,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host); ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host); ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host); ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_bias_buf(x_bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
...@@ -155,6 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -155,6 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
x_bias_buf.ToDevice(x_bias_host.data());
gamma_buf.ToDevice(gamma_host.data()); gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data()); beta_buf.ToDevice(beta_host.data());
x_residual_buf.ToDevice(x_residual_host.data()); x_residual_buf.ToDevice(x_residual_host.data());
...@@ -179,11 +186,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -179,11 +186,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< ", yr_stride:" << yr_stride << std::flush; << ", yr_stride:" << yr_stride << std::flush;
layernorm2d_fwd_traits traits{ layernorm2d_fwd_traits traits{
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
x_bias_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(),
...@@ -210,8 +218,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -210,8 +218,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false; return false;
} }
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XBiasDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; sizeof(GammaDataType) * n + sizeof(BetaDataType) * n +
sizeof(YDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
...@@ -221,6 +230,22 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -221,6 +230,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
// reference // reference
if(xbias != 0)
{
// add bias before fadd
int M = x_host.mDesc.get_lengths()[0];
int N = x_host.mDesc.get_lengths()[1];
for(int idx_m = 0; idx_m < M; ++idx_m)
{
for(int idx_n = 0; idx_n < N; ++idx_n)
{
x_host(idx_m, idx_n) = ck_tile::type_convert<XDataType>(
ck_tile::type_convert<ComputeDataType>(x_host(idx_m, idx_n)) +
ck_tile::type_convert<ComputeDataType>(x_bias_host(idx_n)));
}
}
}
if(fused_add != 0) if(fused_add != 0)
{ {
// fused pre_add/pre_add_store // fused pre_add/pre_add_store
......
...@@ -16,6 +16,7 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleData ...@@ -16,6 +16,7 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleData
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using YDataType = OutType; using YDataType = OutType;
using XBiasDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t; using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t; using BetaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t; using MeanDataType = ck_tile::half_t;
...@@ -30,6 +31,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleData ...@@ -30,6 +31,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleData
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using YDataType = OutType; using YDataType = OutType;
using XBiasDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t; using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t; using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t; using MeanDataType = ck_tile::bf16_t;
...@@ -57,6 +59,7 @@ struct layernorm2d_fwd_traits ...@@ -57,6 +59,7 @@ struct layernorm2d_fwd_traits
std::string prec_sy; // y-scale, used for [M*1] output for next layer std::string prec_sy; // y-scale, used for [M*1] output for next layer
bool save_mean_var; // bool save_mean_var; //
int xbias; // 0:no-bias, 1:add bias
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
}; };
......
...@@ -15,6 +15,7 @@ struct Layernorm2dFwdHostArgs ...@@ -15,6 +15,7 @@ struct Layernorm2dFwdHostArgs
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -43,6 +44,7 @@ struct Layernorm2dFwd ...@@ -43,6 +44,7 @@ struct Layernorm2dFwd
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -67,6 +69,7 @@ struct Layernorm2dFwd ...@@ -67,6 +69,7 @@ struct Layernorm2dFwd
static constexpr bool kPadM = false; // always no need to pad along M static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass; static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -82,6 +85,7 @@ struct Layernorm2dFwd ...@@ -82,6 +85,7 @@ struct Layernorm2dFwd
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -108,6 +112,7 @@ struct Layernorm2dFwd ...@@ -108,6 +112,7 @@ struct Layernorm2dFwd
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_x_residual, hargs.p_x_residual,
hargs.p_x_scale, hargs.p_x_scale,
hargs.p_x_bias,
hargs.p_gamma, hargs.p_gamma,
hargs.p_beta, hargs.p_beta,
hargs.p_y, hargs.p_y,
...@@ -152,6 +157,7 @@ struct Layernorm2dFwd ...@@ -152,6 +157,7 @@ struct Layernorm2dFwd
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
auto surfix = [&] () { auto surfix = [&] () {
std::string n; std::string n;
if (kXbias != Layernorm2dXBiasEnum::NO_BIAS) n += _SS_("_") + Layernorm2dXBiasEnumName<kXbias>::name;
if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name; if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name; if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
if (kPadN) n += "_pn"; if (kPadN) n += "_pn";
...@@ -228,6 +234,27 @@ struct Layernorm2dFwd ...@@ -228,6 +234,27 @@ struct Layernorm2dFwd
} }
}(); }();
const auto x_bias_window = [&]() {
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XBiasDataType*>(kargs.p_x_bias),
make_tuple(kargs.n),
make_tuple(1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_N>{}));
}
}();
const auto gamma_window = [&]() { const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma), static_cast<const GammaDataType*>(kargs.p_gamma),
...@@ -371,6 +398,7 @@ struct Layernorm2dFwd ...@@ -371,6 +398,7 @@ struct Layernorm2dFwd
Pipeline{}(x_window, Pipeline{}(x_window,
x_residual_window, x_residual_window,
x_bias_window,
gamma_window, gamma_window,
beta_window, beta_window,
y_window, y_window,
......
...@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -38,6 +39,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -38,6 +39,7 @@ struct Layernorm2dFwdPipelineOnePass
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford; static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -55,6 +57,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -55,6 +57,7 @@ struct Layernorm2dFwdPipelineOnePass
template <typename XWindow, template <typename XWindow,
typename XResidualWindow, typename XResidualWindow,
typename XBiasWindow,
typename GammaWindow, typename GammaWindow,
typename BetaWindow, typename BetaWindow,
typename YWindow, typename YWindow,
...@@ -66,6 +69,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -66,6 +69,7 @@ struct Layernorm2dFwdPipelineOnePass
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_, const XResidualWindow& x_residual_window_,
const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
const BetaWindow& beta_window_, const BetaWindow& beta_window_,
YWindow& y_window_, YWindow& y_window_,
...@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass
{ {
const auto x_window = const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto x_bias_window = make_tile_window(
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window( const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto beta_window = make_tile_window( const auto beta_window = make_tile_window(
...@@ -90,8 +96,9 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -90,8 +96,9 @@ struct Layernorm2dFwdPipelineOnePass
auto y_residual_window = make_tile_window( auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>()); y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto x_bias = load_tile(x_bias_window);
int cur_count = 0; int cur_count = 0;
int max_count = int max_count =
...@@ -112,6 +119,15 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -112,6 +119,15 @@ struct Layernorm2dFwdPipelineOnePass
auto acc = cast_tile<ComputeDataType>(x); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
namespace ck_tile { namespace ck_tile {
template <typename XDataType_, template <typename XDataType_,
typename XBiasDataType_,
typename GammaDataType_, typename GammaDataType_,
typename BetaDataType_, typename BetaDataType_,
typename ComputeDataType_, typename ComputeDataType_,
...@@ -21,6 +22,7 @@ template <typename XDataType_, ...@@ -21,6 +22,7 @@ template <typename XDataType_,
struct Layernorm2dFwdPipelineProblem struct Layernorm2dFwdPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using XBiasDataType = remove_cvref_t<XBiasDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>; using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
......
...@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -37,6 +38,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -37,6 +38,7 @@ struct Layernorm2dFwdPipelineTwoPass
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford; static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -54,6 +56,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -54,6 +56,7 @@ struct Layernorm2dFwdPipelineTwoPass
template <typename XWindow, template <typename XWindow,
typename XResidualWindow, typename XResidualWindow,
typename XBiasWindow,
typename GammaWindow, typename GammaWindow,
typename BetaWindow, typename BetaWindow,
typename YWindow, typename YWindow,
...@@ -65,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -65,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_, const XResidualWindow& x_residual_window_,
const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
const BetaWindow& beta_window_, const BetaWindow& beta_window_,
YWindow& y_window, YWindow& y_window,
...@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineTwoPass
static_assert(kWelford == true, "2 pass only supports welford merge"); static_assert(kWelford == true, "2 pass only supports welford merge");
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x_bias_window = make_tile_window(
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window( auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto beta_window = make_tile_window( auto beta_window = make_tile_window(
...@@ -115,13 +121,24 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -115,13 +121,24 @@ struct Layernorm2dFwdPipelineTwoPass
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto x_bias = load_tile(x_bias_window);
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N}); move_tile_window(x_residual_window, {0, Block_N});
move_tile_window(x_bias_window, {Block_N});
auto acc = cast_tile<ComputeDataType>(x); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
...@@ -167,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -167,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window});
...@@ -174,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -174,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation // layernorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x); const auto x_bias = load_tile(x_bias_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
...@@ -209,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -209,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N}); move_tile_window(beta_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N}); move_tile_window(y_window, {0, -Block_N});
......
...@@ -7,6 +7,19 @@ ...@@ -7,6 +7,19 @@
namespace ck_tile { namespace ck_tile {
enum class Layernorm2dXBiasEnum
{
NO_BIAS = 0,
// add bias before fused add
ADD_BIAS = 1,
};
// clang-format off
template<Layernorm2dXBiasEnum> struct Layernorm2dXBiasEnumName;
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::NO_BIAS> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::ADD_BIAS> { static constexpr const char * name = "xbias"; };
// clang-format on
enum class Layernorm2dFusedAddEnum enum class Layernorm2dFusedAddEnum
{ {
NO_ADD = 0, NO_ADD = 0,
...@@ -42,6 +55,7 @@ template <bool kPadN_, ...@@ -42,6 +55,7 @@ template <bool kPadN_,
bool kFastFDiv_, bool kFastFDiv_,
bool kWelford_, bool kWelford_,
bool kTwoPass_, bool kTwoPass_,
Layernorm2dXBiasEnum kXbias_,
Layernorm2dFusedAddEnum kFusedAdd_, Layernorm2dFusedAddEnum kFusedAdd_,
Layernorm2dFusedQuantEnum kFusedQuant_> Layernorm2dFusedQuantEnum kFusedQuant_>
struct Layernorm2dFwdTraits struct Layernorm2dFwdTraits
...@@ -51,6 +65,7 @@ struct Layernorm2dFwdTraits ...@@ -51,6 +65,7 @@ struct Layernorm2dFwdTraits
static constexpr bool kFastFDiv = kFastFDiv_; static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_; static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_; static constexpr bool kTwoPass = kTwoPass_;
static constexpr Layernorm2dXBiasEnum kXbias = kXbias_;
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_; static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment