Commit f199c936 authored by AMD-dteng's avatar AMD-dteng
Browse files

1.remove fmha change 2.change buffer name from bias to xbias

parent ec07718a
...@@ -410,8 +410,8 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -410,8 +410,8 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1), '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), #'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
#'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), ## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
......
...@@ -198,7 +198,7 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -198,7 +198,7 @@ float layernorm2d_fwd_(const S& s, A a)
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>; static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType, typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BiasDataType, typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XBiasDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType, typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType, typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType, typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType,
...@@ -330,7 +330,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -330,7 +330,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
@dataclass @dataclass
class k_problem: class k_problem:
F_XDataType : str F_XDataType : str
F_BiasDataType : str F_XBiasDataType : str
F_GammaDataType : str F_GammaDataType : str
F_BetaDataType : str F_BetaDataType : str
F_ComputeDataType : str F_ComputeDataType : str
......
...@@ -109,7 +109,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -109,7 +109,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType; using YDataType = typename TypeConfig::YDataType;
using BiasDataType = typename TypeConfig::BiasDataType; using XBiasDataType = typename TypeConfig::XBiasDataType;
using GammaDataType = typename TypeConfig::GammaDataType; using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType; using BetaDataType = typename TypeConfig::BetaDataType;
using XResidualDataType = XDataType; using XResidualDataType = XDataType;
...@@ -124,7 +124,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -124,7 +124,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<BiasDataType> bias_host({n}); ck_tile::HostTensor<XBiasDataType> x_bias_host({n});
ck_tile::HostTensor<GammaDataType> gamma_host({n}); ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<BetaDataType> beta_host({n}); ck_tile::HostTensor<BetaDataType> beta_host({n});
...@@ -145,12 +145,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -145,12 +145,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host); ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host); ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::FillUniformDistribution<BiasDataType>{-.5f, .5f}(bias_host); ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host); ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_bias_buf(x_bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
...@@ -161,7 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -161,7 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
bias_buf.ToDevice(bias_host.data()); x_bias_buf.ToDevice(x_bias_host.data());
gamma_buf.ToDevice(gamma_host.data()); gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data()); beta_buf.ToDevice(beta_host.data());
x_residual_buf.ToDevice(x_residual_host.data()); x_residual_buf.ToDevice(x_residual_host.data());
...@@ -191,7 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -191,7 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
bias_buf.GetDeviceBuffer(), x_bias_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(),
...@@ -218,8 +218,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -218,8 +218,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false; return false;
} }
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(BiasDataType) * n + sizeof(GammaDataType) * n + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XBiasDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; sizeof(GammaDataType) * n + sizeof(BetaDataType) * n +
sizeof(YDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
...@@ -240,7 +241,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -240,7 +241,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
for(int idx_n = 0; idx_n < N; ++idx_n) for(int idx_n = 0; idx_n < N; ++idx_n)
{ {
x_host(idx_m, idx_n) = ck_tile::type_convert<XDataType>(ck_tile::type_convert<ComputeDataType>(x_host(idx_m, idx_n)) + ck_tile::type_convert<ComputeDataType>(bias_host(idx_n))); x_host(idx_m, idx_n) = ck_tile::type_convert<XDataType>(
ck_tile::type_convert<ComputeDataType>(x_host(idx_m, idx_n)) +
ck_tile::type_convert<ComputeDataType>(x_bias_host(idx_n)));
} }
} }
} }
......
...@@ -16,7 +16,7 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleData ...@@ -16,7 +16,7 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleData
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using YDataType = OutType; using YDataType = OutType;
using BiasDataType = ck_tile::half_t; using XBiasDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t; using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t; using BetaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t; using MeanDataType = ck_tile::half_t;
...@@ -31,7 +31,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleData ...@@ -31,7 +31,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleData
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using YDataType = OutType; using YDataType = OutType;
using BiasDataType = ck_tile::bf16_t; using XBiasDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t; using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t; using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t; using MeanDataType = ck_tile::bf16_t;
......
...@@ -15,7 +15,7 @@ struct Layernorm2dFwdHostArgs ...@@ -15,7 +15,7 @@ struct Layernorm2dFwdHostArgs
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_bias; // [1, n], bias, prec same as input const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -44,7 +44,7 @@ struct Layernorm2dFwd ...@@ -44,7 +44,7 @@ struct Layernorm2dFwd
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -85,7 +85,7 @@ struct Layernorm2dFwd ...@@ -85,7 +85,7 @@ struct Layernorm2dFwd
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_bias; // [1, n], bias, prec same as input const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -112,7 +112,7 @@ struct Layernorm2dFwd ...@@ -112,7 +112,7 @@ struct Layernorm2dFwd
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_x_residual, hargs.p_x_residual,
hargs.p_x_scale, hargs.p_x_scale,
hargs.p_bias, hargs.p_x_bias,
hargs.p_gamma, hargs.p_gamma,
hargs.p_beta, hargs.p_beta,
hargs.p_y, hargs.p_y,
...@@ -234,11 +234,11 @@ struct Layernorm2dFwd ...@@ -234,11 +234,11 @@ struct Layernorm2dFwd
} }
}(); }();
const auto bias_window = [&]() { const auto x_bias_window = [&]() {
if constexpr(kBias == Layernorm2dBiasEnum::ADD_BIAS) if constexpr(kBias == Layernorm2dBiasEnum::ADD_BIAS)
{ {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BiasDataType*>(kargs.p_bias), static_cast<const XBiasDataType*>(kargs.p_x_bias),
make_tuple(kargs.n), make_tuple(kargs.n),
make_tuple(1), make_tuple(1),
number<Vector_N>{}, number<Vector_N>{},
...@@ -398,7 +398,7 @@ struct Layernorm2dFwd ...@@ -398,7 +398,7 @@ struct Layernorm2dFwd
Pipeline{}(x_window, Pipeline{}(x_window,
x_residual_window, x_residual_window,
bias_window, x_bias_window,
gamma_window, gamma_window,
beta_window, beta_window,
y_window, y_window,
......
...@@ -18,7 +18,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -18,7 +18,7 @@ struct Layernorm2dFwdPipelineOnePass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename Problem::BiasDataType>; using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -56,7 +56,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -56,7 +56,7 @@ struct Layernorm2dFwdPipelineOnePass
template <typename XWindow, template <typename XWindow,
typename XResidualWindow, typename XResidualWindow,
typename BiasWindow, typename XBiasWindow,
typename GammaWindow, typename GammaWindow,
typename BetaWindow, typename BetaWindow,
typename YWindow, typename YWindow,
...@@ -68,7 +68,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -68,7 +68,7 @@ struct Layernorm2dFwdPipelineOnePass
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_, const XResidualWindow& x_residual_window_,
const BiasWindow& bias_window_, const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
const BetaWindow& beta_window_, const BetaWindow& beta_window_,
YWindow& y_window_, YWindow& y_window_,
...@@ -84,8 +84,8 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -84,8 +84,8 @@ struct Layernorm2dFwdPipelineOnePass
{ {
const auto x_window = const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto bias_window = make_tile_window( const auto x_bias_window = make_tile_window(
bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window( const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto beta_window = make_tile_window( const auto beta_window = make_tile_window(
...@@ -97,7 +97,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -97,7 +97,7 @@ struct Layernorm2dFwdPipelineOnePass
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto bias = load_tile(bias_window); const auto x_bias = load_tile(x_bias_window);
int cur_count = 0; int cur_count = 0;
int max_count = int max_count =
...@@ -121,7 +121,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -121,7 +121,7 @@ struct Layernorm2dFwdPipelineOnePass
{ {
// compute x = bias + x // compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(bias[j_idx]) + acc(idx); acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
} }
// compute x = x_resi + x // compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx); acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace ck_tile { namespace ck_tile {
template <typename XDataType_, template <typename XDataType_,
typename BiasDataType_, typename XBiasDataType_,
typename GammaDataType_, typename GammaDataType_,
typename BetaDataType_, typename BetaDataType_,
typename ComputeDataType_, typename ComputeDataType_,
...@@ -22,7 +22,7 @@ template <typename XDataType_, ...@@ -22,7 +22,7 @@ template <typename XDataType_,
struct Layernorm2dFwdPipelineProblem struct Layernorm2dFwdPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>; using XBiasDataType = remove_cvref_t<XBiasDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>; using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
......
...@@ -17,7 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -17,7 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename Problem::BiasDataType>; using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -55,7 +55,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -55,7 +55,7 @@ struct Layernorm2dFwdPipelineTwoPass
template <typename XWindow, template <typename XWindow,
typename XResidualWindow, typename XResidualWindow,
typename BiasWindow, typename XBiasWindow,
typename GammaWindow, typename GammaWindow,
typename BetaWindow, typename BetaWindow,
typename YWindow, typename YWindow,
...@@ -67,7 +67,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -67,7 +67,7 @@ struct Layernorm2dFwdPipelineTwoPass
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_, const XResidualWindow& x_residual_window_,
const BiasWindow& bias_window_, const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
const BetaWindow& beta_window_, const BetaWindow& beta_window_,
YWindow& y_window, YWindow& y_window,
...@@ -83,8 +83,8 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -83,8 +83,8 @@ struct Layernorm2dFwdPipelineTwoPass
{ {
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto bias_window = make_tile_window( auto x_bias_window = make_tile_window(
bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window( auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto beta_window = make_tile_window( auto beta_window = make_tile_window(
...@@ -121,11 +121,11 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -121,11 +121,11 @@ struct Layernorm2dFwdPipelineTwoPass
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto bias = load_tile(bias_window); const auto x_bias = load_tile(x_bias_window);
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N}); move_tile_window(x_residual_window, {0, Block_N});
move_tile_window(bias_window, {Block_N}); move_tile_window(x_bias_window, {Block_N});
auto acc = cast_tile<ComputeDataType>(x); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
...@@ -136,7 +136,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -136,7 +136,7 @@ struct Layernorm2dFwdPipelineTwoPass
{ {
// compute x = bias + x // compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(bias[j_idx]) + acc(idx); acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
} }
// compute x = x_resi + x // compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx); acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
...@@ -179,7 +179,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -179,7 +179,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(bias_window, {-Block_N}); move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window});
...@@ -189,7 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -189,7 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto bias = load_tile(bias_window); const auto x_bias = load_tile(x_bias_window);
auto acc = cast_tile<ComputeDataType>(x); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
...@@ -200,7 +200,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -200,7 +200,7 @@ struct Layernorm2dFwdPipelineTwoPass
{ {
// compute x = bias + x // compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(bias[j_idx]) + acc(idx); acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
} }
// compute x = x_resi + x // compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx); acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
...@@ -229,7 +229,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -229,7 +229,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(bias_window, {-Block_N}); move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N}); move_tile_window(beta_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N}); move_tile_window(y_window, {0, -Block_N});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment