"vscode:/vscode.git/clone" did not exist on "41afddffaee0d36f4540dbab546c95bfb5603cd3"
Commit 4b59b5c9 authored by carlushuang's avatar carlushuang
Browse files

add prenorm/postnorm support, refactor using generate.py

parent 4d5248e2
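For orientation: "fused add" gives the kernel a second, "shortcut" (residual) input sx. In prenorm mode (fadd=1) the kernel computes y = layernorm(x + sx) and also stores the sum x + sx through a shortcut output sy for the next layer's residual chain; in postnorm mode (fadd=2) it computes the same y but does not store the sum. A scalar host-side sketch of these semantics (illustrative only; layer_norm_row is a hypothetical helper, and the real kernel does all of this fused on distributed tiles):

#include <cmath>
#include <cstddef>

// Hypothetical scalar reference for one row of length n (not ck_tile code).
void layer_norm_row(float* x, const float* gamma, const float* beta, std::size_t n, float eps)
{
    float mean = 0.f, var = 0.f;
    for(std::size_t j = 0; j < n; ++j) mean += x[j];
    mean /= n;
    for(std::size_t j = 0; j < n; ++j) var += (x[j] - mean) * (x[j] - mean);
    var /= n;
    for(std::size_t j = 0; j < n; ++j)
        x[j] = (x[j] - mean) / std::sqrt(var + eps) * gamma[j] + beta[j];
}

void fused_add_layernorm_row(float* x, const float* sx, float* sy, const float* gamma,
                             const float* beta, std::size_t n, float eps, int fused_add)
{
    if(fused_add != 0) // both modes: pre-add the shortcut input
        for(std::size_t j = 0; j < n; ++j) x[j] += sx[j];
    if(fused_add == 1) // prenorm: also store the sum for the residual chain
        for(std::size_t j = 0; j < n; ++j) sy[j] = x[j];
    layer_norm_row(x, gamma, beta, n, eps); // normalize (the sum, if fused_add != 0)
}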
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p   (rm/rn = Repeat_M/N, tm/tn = ThreadPerBlock_M/N, vn = Vector_N, pd = kPadN, mv = kSaveMeanInvStd, 2p = kTwoPass)
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
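Each `template float layernorm2d_fwd_<trait_<...>>` line above instantiates one tile configuration. Reading the column legend against the layernorm2d_fwd_traits_ definition later in this diff, the row width covered per block is Block_N = Repeat_N * ThreadPerBlock_N * Vector_N, so the first file's four fp16 variants all cover Block_N = 512 (1·64·8 = 2·64·4 = 4·64·2 = 8·64·1) at different vector widths, the second file covers Block_N = 64 and 128, and the third covers Block_N = 768.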
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
#include <iostream>
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;
template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t Repeat_N_,         // each thread repeat along N
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kSaveMeanInvStd_,
          bool kTwoPass_>
using trait_ = layernorm2d_fwd_traits_<DataType_,
                                       Repeat_M_,
                                       Repeat_N_,
                                       ThreadPerBlock_M_,
                                       ThreadPerBlock_N_,
                                       Vector_N_,
                                       kPadN_,
                                       kSaveMeanInvStd_,
                                       kTwoPass_>;

template <typename Traits_>
float layernorm2d_fwd_(const S& s, A a)
{
    using DataType = typename Traits_::DataType;

    using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
        typename LayerNormTypeConfig<DataType>::XDataType,
        typename LayerNormTypeConfig<DataType>::GammaDataType,
        typename LayerNormTypeConfig<DataType>::BetaDataType,
        typename LayerNormTypeConfig<DataType>::ComputeDataType,
        typename LayerNormTypeConfig<DataType>::YDataType,
        typename LayerNormTypeConfig<DataType>::MeanDataType,
        typename LayerNormTypeConfig<DataType>::InvStdDataType,
        typename Traits_::Shape,
        Traits_::kPadN,
        Traits_::kSaveMeanInvStd,
        Traits_::kTwoPass>;

    using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
    using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
    using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;

    using Kernel = ck_tile::Layernorm2dFwd<Pipeline>;

    const dim3 grids                       = Kernel::GridSize(a);
    constexpr dim3 blocks                  = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    auto kargs = Kernel::MakeKargs(a);
    if(s.log_level_ > 0)
        std::cout << ", " << Kernel::GetName() << std::flush;

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp" #include "layernorm2d_fwd.hpp"
#include <cstring> #include <cstring>
#include <algorithm>
// different threshold for different dtype // different threshold for different dtype
template <typename DataType> template <typename DataType>
...@@ -29,7 +30,13 @@ auto create_args(int argc, char* argv[]) ...@@ -29,7 +30,13 @@ auto create_args(int argc, char* argv[])
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case") .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision") .insert("prec_i", "fp16", "input precision")
.insert("prec_o", "auto", "output precision, set auto will be the same as input")
.insert(
"fadd",
"0",
"fused-add, 0:no fused add, 1:fused-prenorm(preadd+store), 2:fused-postnorm(preadd)")
.insert("fsweep", "0", "fused-sweep")
.insert("warmup", "5", "cold iter") .insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter"); .insert("repeat", "20", "hot iter");
...@@ -37,7 +44,7 @@ auto create_args(int argc, char* argv[]) ...@@ -37,7 +44,7 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
} }
template <typename DataType, bool SaveMeanVar> template <typename InDataType, typename OutDataType, bool SaveMeanVar>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t m = arg_parser.get_int("m");
...@@ -46,20 +53,29 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -46,20 +53,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(stride < 0) if(stride < 0)
stride = n; stride = n;
float epsilon = arg_parser.get_float("e"); float epsilon = arg_parser.get_float("e");
std::string data_type = arg_parser.get_str("prec"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
if(prec_o == "auto")
{
prec_o = prec_i;
}
int kname = arg_parser.get_int("kname"); int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
int fused_add = arg_parser.get_int("fadd");
int fused_sweep = arg_parser.get_int("fsweep");
assert(stride >= n); assert(stride >= n);
using TypeConfig = LayerNormTypeConfig<DataType>; using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType>;
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType; using YDataType = typename TypeConfig::YDataType;
using GammaDataType = typename TypeConfig::GammaDataType; using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType; using BetaDataType = typename TypeConfig::BetaDataType;
using SXDataType = XDataType;
using SYDataType = YDataType;
using MeanDataType = using MeanDataType =
std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>; std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;
...@@ -73,6 +89,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -73,6 +89,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<GammaDataType> gamma_host({n}); ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<BetaDataType> beta_host({n}); ck_tile::HostTensor<BetaDataType> beta_host({n});
ck_tile::HostTensor<SXDataType> sx_host({m, n}, {stride, 1});
ck_tile::HostTensor<SYDataType> sy_host({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
...@@ -88,19 +107,25 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -88,19 +107,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem sx_buf(sx_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sy_buf(sy_host.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data()); gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data()); beta_buf.ToDevice(beta_host.data());
sx_buf.ToDevice(sx_host.data());
std::cout << "[" << data_type << "]" std::cout << "[" << prec_i << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
layernorm2d_fwd_traits traits{data_type, SaveMeanVar}; layernorm2d_fwd_traits traits{prec_i, prec_o, SaveMeanVar, fused_add, fused_sweep};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? sx_buf.GetDeviceBuffer() : nullptr,
gamma_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(), y_buf.GetDeviceBuffer(),
fused_add == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
nullptr, nullptr,
nullptr, nullptr,
epsilon, epsilon,
...@@ -111,6 +136,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -111,6 +136,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = layernorm2d_fwd( float ave_time = layernorm2d_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
...@@ -122,6 +153,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -122,6 +153,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
// reference // reference
if(fused_add != 0)
{
// fused pre_add/pre_add_store
// TODO we accumulate directly to x_host for simplcity here...
std::transform(x_host.mData.cbegin(),
x_host.mData.cend(),
sx_host.mData.cbegin(),
x_host.mData.begin(),
std::plus<XDataType>{});
}
ck_tile::reference_layernorm2d_fwd<XDataType, ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
...@@ -133,11 +175,22 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -133,11 +175,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_buf.FromDevice(y_host_dev.data()); y_buf.FromDevice(y_host_dev.data());
auto [rtol, atol] = get_elimit<DataType>(); ck_tile::HostTensor<SYDataType> sy_host_dev({m, n}, {stride, 1});
if(fused_add == 1)
{
sy_buf.FromDevice(sy_host_dev.data());
}
auto [rtol, atol] = get_elimit<InDataType>();
if(stride == n) if(stride == n)
{ {
pass = ck_tile::check_err( pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
if(fused_add == 1)
{
pass &= ck_tile::check_err(
sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol);
}
} }
else else
{ {
...@@ -153,6 +206,19 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -153,6 +206,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::string("] Error: Incorrect results!"), std::string("] Error: Incorrect results!"),
rtol, rtol,
atol); atol);
if(fused_add == 1)
{
std::vector<SYDataType> sy_host_dev_row(sy_host_dev.begin() + i_r * stride,
sy_host_dev.begin() + i_r * stride + n);
std::vector<SYDataType> sy_host_ref_row(x_host.begin() + i_r * stride,
x_host.begin() + i_r * stride + n);
pass &= ck_tile::check_err(sy_host_dev_row,
sy_host_ref_row,
std::string("ADD[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
} }
} }
...@@ -168,23 +234,28 @@ int main(int argc, char* argv[]) ...@@ -168,23 +234,28 @@ int main(int argc, char* argv[])
if(!result) if(!result)
return -1; return -1;
const std::string data_type = arg_parser.get_str("prec"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
if(prec_o == "auto")
{
prec_o = prec_i;
}
int save_mv = arg_parser.get_int("save_mv"); int save_mv = arg_parser.get_int("save_mv");
if(data_type == "fp16" && save_mv) if(prec_i == "fp16" && prec_o == "fp16" && save_mv)
{ {
return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2; return run<ck_tile::half_t, ck_tile::half_t, true>(arg_parser) ? 0 : -2;
} }
else if(data_type == "fp16" && !save_mv) else if(prec_i == "fp16" && prec_o == "fp16" && !save_mv)
{ {
return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2; return run<ck_tile::half_t, ck_tile::half_t, false>(arg_parser) ? 0 : -2;
} }
else if(data_type == "bf16" && save_mv) else if(prec_i == "bf16" && prec_o == "bf16" && save_mv)
{ {
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
} }
else if(data_type == "bf16" && !save_mv) else if(prec_i == "bf16" && prec_o == "bf16" && !save_mv)
{ {
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
} }
return -3; return -3;
@@ -8,14 +8,14 @@
 #include "ck_tile/ops/layernorm2d.hpp"
 #include <string>

-template <typename DataType>
+template <typename InType, typename OutType>
 struct LayerNormTypeConfig;

-template <>
-struct LayerNormTypeConfig<ck_tile::half_t>
+template <typename OutType>
+struct LayerNormTypeConfig<ck_tile::half_t, OutType>
 {
     using XDataType = ck_tile::half_t;
-    using YDataType = ck_tile::half_t;
+    using YDataType = OutType;
     using GammaDataType = ck_tile::half_t;
     using BetaDataType = ck_tile::half_t;
     using MeanDataType = ck_tile::half_t;
@@ -23,11 +23,11 @@ struct LayerNormTypeConfig<ck_tile::half_t>
     using ComputeDataType = float;
 };

-template <>
-struct LayerNormTypeConfig<ck_tile::bf16_t>
+template <typename OutType>
+struct LayerNormTypeConfig<ck_tile::bf16_t, OutType>
 {
     using XDataType = ck_tile::bf16_t;
-    using YDataType = ck_tile::bf16_t;
+    using YDataType = OutType;
     using GammaDataType = ck_tile::bf16_t;
     using BetaDataType = ck_tile::bf16_t;
     using MeanDataType = ck_tile::bf16_t;
@@ -40,82 +40,17 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
 {
 };

-// this is used to pattern-match internal kernel implementation, not to instantiate kernel
-template <typename DataType_,
-          ck_tile::index_t Repeat_M_,         // each thread repeat along M
-          ck_tile::index_t Repeat_N_,         // each thread repeat along N
-          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
-          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
-          ck_tile::index_t Vector_N_,         // vector size along N
-          bool kPadN_,
-          bool kSaveMeanInvStd_,
-          bool kTwoPass_>
-struct layernorm2d_fwd_traits_
-{
-    using DataType = ck_tile::remove_cvref_t<DataType_>;
-    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
-    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
-    static constexpr ck_tile::index_t total_warps =
-        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
-    // num of warps along m
-    static constexpr ck_tile::index_t BlockWarps_M = []() {
-        if constexpr(is_warp_per_row)
-        {
-            static_assert(warpSize % ThreadPerBlock_N_ == 0);
-            return total_warps * (warpSize / ThreadPerBlock_N_);
-        }
-        else
-        {
-            // static_assert(warpSize % ThreadPerBlock_M_ == 0);
-            return total_warps / (ThreadPerBlock_N_ / warpSize);
-        }
-    }();
-    // num of warps along n
-    static constexpr ck_tile::index_t BlockWarps_N = []() {
-        if constexpr(is_warp_per_row)
-        {
-            static_assert(warpSize % ThreadPerBlock_N_ == 0);
-            return 1;
-        }
-        else
-        {
-            static_assert(ThreadPerBlock_N_ % warpSize == 0);
-            return ThreadPerBlock_N_ / warpSize;
-        }
-    }();
-    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
-    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
-    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
-    static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
-    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
-    static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
-    using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
-    using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
-    using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
-    using Vector     = ck_tile::sequence<1, Vector_N_>;
-    using Shape = ck_tile::Layernorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
-    static constexpr bool kPadN           = kPadN_;
-    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
-    static constexpr bool kTwoPass        = kTwoPass_;
-};
-
 template <typename Traits_>
 float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);

 // This is the public API, will be generated by script
 struct layernorm2d_fwd_traits
 {
-    std::string data_type;
+    std::string prec_i;
+    std::string prec_o;
     bool save_mean_var;
+    int fused_add;   // 0:no-add, 1:pre-add, 2:post-add
+    int fused_sweep; // 0:no-sweep
 };

 float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
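A minimal sketch of calling this public API with the new fields (hypothetical call site; buffer setup is elided, and the tail extent fields m/n/stride are assumed from the kernel's kargs usage later in this diff):

// prenorm fp16: y = layernorm(x + sx), and the sum is also stored through sy
layernorm2d_fwd_traits traits{"fp16", "fp16", /*save_mean_var=*/false,
                              /*fused_add=*/1, /*fused_sweep=*/0};
layernorm2d_fwd_args args{p_x, p_sx, p_gamma, p_beta, p_y, p_sy,
                          /*p_mean=*/nullptr, /*p_invStd=*/nullptr,
                          epsilon, m, n, stride};
// stream_config initializers mirror the example's usage: warmup 5, repeat 20
float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true, 0, 5, 20});
if(ave_time < 0) { /* this traits combination has no generated instance */ }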
@@ -2,37 +2,37 @@
 # run from top of ck folder
 EXE=build/bin/tile_example_layernorm2d_fwd
-$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
+$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
\ No newline at end of file
@@ -3,29 +3,31 @@
 EXE=./build/bin/tile_example_layernorm2d_fwd

 for pr_i in "fp16" "bf16" ; do
-$EXE -prec=$pr_i -m=99 -n=13
-$EXE -prec=$pr_i -m=17 -n=16
-$EXE -prec=$pr_i -m=1 -n=100
-$EXE -prec=$pr_i -m=4 -n=128
-$EXE -prec=$pr_i -m=80 -n=127
-$EXE -prec=$pr_i -m=22 -n=255 -stride=256
-$EXE -prec=$pr_i -m=7 -n=599
-$EXE -prec=$pr_i -m=19 -n=512
-$EXE -prec=$pr_i -m=33 -n=313 -stride=1000
-$EXE -prec=$pr_i -m=11 -n=510
-$EXE -prec=$pr_i -m=171 -n=676 -stride=818
-$EXE -prec=$pr_i -m=91 -n=636
-$EXE -prec=$pr_i -m=12 -n=768 -stride=800
-$EXE -prec=$pr_i -m=100 -n=766 -stride=812
-$EXE -prec=$pr_i -m=31 -n=1024
-$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
-$EXE -prec=$pr_i -m=8 -n=1501
-$EXE -prec=$pr_i -m=3 -n=1826
-$EXE -prec=$pr_i -m=5 -n=2040
-$EXE -prec=$pr_i -m=7 -n=2734
-$EXE -prec=$pr_i -m=1 -n=3182
-$EXE -prec=$pr_i -m=9 -n=4096
-$EXE -prec=$pr_i -m=3 -n=8192
-$EXE -prec=$pr_i -m=1 -n=10547
-$EXE -prec=$pr_i -m=3 -n=17134
+for fadd in "0" "1" "2"; do
+$EXE -prec_i=$pr_i -fadd=$fadd -m=99 -n=13
+$EXE -prec_i=$pr_i -fadd=$fadd -m=17 -n=16
+$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=100
+$EXE -prec_i=$pr_i -fadd=$fadd -m=4 -n=128
+$EXE -prec_i=$pr_i -fadd=$fadd -m=80 -n=127
+$EXE -prec_i=$pr_i -fadd=$fadd -m=22 -n=255 -stride=256
+$EXE -prec_i=$pr_i -fadd=$fadd -m=7 -n=599
+$EXE -prec_i=$pr_i -fadd=$fadd -m=19 -n=512
+$EXE -prec_i=$pr_i -fadd=$fadd -m=33 -n=313 -stride=1000
+$EXE -prec_i=$pr_i -fadd=$fadd -m=11 -n=510
+$EXE -prec_i=$pr_i -fadd=$fadd -m=171 -n=676 -stride=818
+$EXE -prec_i=$pr_i -fadd=$fadd -m=91 -n=636
+$EXE -prec_i=$pr_i -fadd=$fadd -m=12 -n=768 -stride=800
+$EXE -prec_i=$pr_i -fadd=$fadd -m=100 -n=766 -stride=812
+$EXE -prec_i=$pr_i -fadd=$fadd -m=31 -n=1024
+$EXE -prec_i=$pr_i -fadd=$fadd -m=64 -n=1000 -stride=1004
+$EXE -prec_i=$pr_i -fadd=$fadd -m=8 -n=1501
+$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=1826
+$EXE -prec_i=$pr_i -fadd=$fadd -m=5 -n=2040
+$EXE -prec_i=$pr_i -fadd=$fadd -m=7 -n=2734
+$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=3182
+$EXE -prec_i=$pr_i -fadd=$fadd -m=9 -n=4096
+$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=8192
+$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=10547
+$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=17134
+done
 done
@@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
     return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
 }

+template <typename WindowLengths, typename StaticTileDistribution>
+CK_TILE_DEVICE constexpr auto make_tile_window(const null_tile_window<WindowLengths>& t,
+                                               const StaticTileDistribution&)
+{
+    return t;
+}
+
 template <typename WindowLengths>
 CK_TILE_DEVICE void
 move_tile_window(null_tile_window<WindowLengths>&,
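The overload added above is a pass-through: attaching a tile distribution to a null_tile_window returns the window unchanged. This lets the layernorm pipelines below call make_tile_window on the sx/sy shortcut windows unconditionally, whether the kernel created a real global-memory window or a null one (when the fused-add feature is disabled).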
@@ -9,4 +9,5 @@
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -5,17 +5,20 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"

 namespace ck_tile {

 // host side args
 struct Layernorm2dFwdHostArgs
 {
-    const void* p_x;
+    const void* p_x;  // input
+    const void* p_sx; // shortcut input, set to nullptr if not used
     const void* p_gamma;
     const void* p_beta;
-    void* p_y;
+    void* p_y;  // output
+    void* p_sy; // shortcut output, set to nullptr if not used
     void* p_mean;
     void* p_invStd;
@@ -27,10 +30,11 @@ struct Layernorm2dFwdHostArgs
 };

 // TODO: Extract some type to wrapper class
-template <typename Pipeline_>
+template <typename Pipeline_, typename Epilogue_>
 struct Layernorm2dFwd
 {
     using Pipeline = remove_cvref_t<Pipeline_>;
+    using Epilogue = remove_cvref_t<Epilogue_>;
     using Problem = typename Pipeline::Problem;

     using XDataType = remove_cvref_t<typename Problem::XDataType>;
@@ -41,17 +45,23 @@ struct Layernorm2dFwd
     using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
     using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
+    // for simplicity, shortcut input/output type is the same as X
+    using SXDataType = XDataType;
+    using SYDataType = XDataType;

     static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
     static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>;
-    static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd;
-    static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
-    static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
+    static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd;
+    static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
+    static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;

     static constexpr index_t Block_M = Problem::BlockShape::Block_M;
     static constexpr index_t Block_N = Problem::BlockShape::Block_N;
     static constexpr bool kPadM = false; // always no need to pad along M
-    static constexpr bool kPadN = Problem::kPadN;
-    static constexpr bool kTwoPass = Problem::kTwoPass;
+    static constexpr bool kPadN = Problem::Traits::kPadN;
+    static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
+    static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
+    static constexpr auto kFusedSweep = Problem::Traits::kFusedSweep;

     static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
     static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
@@ -62,11 +72,13 @@ struct Layernorm2dFwd
     struct Kargs
     {
-        const void* p_x;
+        const void* p_x;  // input
+        const void* p_sx; // shortcut input, set to nullptr if not used
         const void* p_gamma;
         const void* p_beta;
-        void* p_y;
+        void* p_y;  // output
+        void* p_sy; // shortcut output, set to nullptr if not used
         void* p_mean;
         void* p_invStd;
@@ -81,9 +93,11 @@ struct Layernorm2dFwd
     CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
     {
         return Kargs{hargs.p_x,
+                     hargs.p_sx,
                      hargs.p_gamma,
                      hargs.p_beta,
                      hargs.p_y,
+                     hargs.p_sy,
                      hargs.p_mean,
                      hargs.p_invStd,
                      hargs.epsilon,
@@ -113,17 +127,19 @@ struct Layernorm2dFwd
     CK_TILE_HOST static std::string GetName()
     {
         // clang-format off
+        #define _SS_ std::string
+        #define _TS_ std::to_string
         using S_ = typename Problem::BlockShape;
         auto surfix = [&] () {
             std::string n;
+            if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
+            if (kFusedSweep != Layernorm2dFusedSweepEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedSweepEnumName<kFusedSweep>::name;
             if (kPadN) n += "_pn";
             if (kSaveMeanInvStd) n += "_mv";
             if (kTwoPass) n += "_2p";
             return n; }();
-        #define _SS_ std::string
-        #define _TS_ std::to_string
         return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
             _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
             _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
@@ -153,6 +169,31 @@ struct Layernorm2dFwd
                 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
         }();

+        const auto sx_window = [&]() {
+            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
+                         kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
+            {
+                const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
+                    static_cast<const SXDataType*>(kargs.p_sx),
+                    make_tuple(kargs.m, kargs.n),
+                    make_tuple(kargs.stride, 1),
+                    number<Vector_N>{},
+                    number<1>{});
+                // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel
+                // will check the max count dynamically
+                const auto tmp2_ = pad_tensor_view(tmp_,
+                                                   make_tuple(number<Block_M>{}, number<Block_N>{}),
+                                                   sequence<false, false>{});
+                return make_tile_window(
+                    tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
+            }
+            else
+            {
+                return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
+            }
+        }();
+
         const auto gamma_window = [&]() {
             const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                 static_cast<const GammaDataType*>(kargs.p_gamma),
@@ -194,6 +235,28 @@ struct Layernorm2dFwd
                 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
         }();

+        auto sy_window = [&]() {
+            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
+            {
+                auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
+                    static_cast<SYDataType*>(kargs.p_sy),
+                    make_tuple(kargs.m, kargs.n),
+                    make_tuple(kargs.stride, 1),
+                    number<Vector_N>{},
+                    number<1>{});
+                auto tmp2_ = pad_tensor_view(tmp_,
+                                             make_tuple(number<Block_M>{}, number<Block_N>{}),
+                                             sequence<kPadM, kPadN>{});
+                return make_tile_window(
+                    tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
+            }
+            else
+            {
+                return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
+            }
+        }();
+
         auto mean_window = [&]() {
             if constexpr(kSaveMean)
             {
@@ -235,14 +298,17 @@ struct Layernorm2dFwd
         __shared__ char smem[GetSmemSize()];

         Pipeline{}(x_window,
+                   sx_window,
                    gamma_window,
                    beta_window,
                    y_window,
+                   sy_window,
                    mean_window,
                    inv_std_window,
                    static_cast<const ComputeDataType>(kargs.epsilon),
                    kargs.n,
-                   smem);
+                   smem,
+                   Epilogue{});
     }
 };
@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
 #include <string>
 #include <type_traits>
@@ -24,14 +25,19 @@ struct Layernorm2dFwdPipelineOnePass
     using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
     using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
+    using SXDataType = XDataType;
+    using SYDataType = XDataType;

     static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
     static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
-    static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
-    static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
+    static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
+    static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;

     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
-    static constexpr bool kPadN = Problem::kPadN;
+    static constexpr bool kPadN = Problem::Traits::kPadN;
+    static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
+    static constexpr auto kFusedSweep = Problem::Traits::kFusedSweep;

     static constexpr const char* name = []() {
         if constexpr(kNeedCrossWarpSync)
@@ -46,20 +52,26 @@ struct Layernorm2dFwdPipelineOnePass
     }

     template <typename XWindow,
+              typename SXWindow,
               typename GammaWindow,
               typename BetaWindow,
               typename YWindow,
+              typename SYWindow,
               typename MeanWindow,
-              typename InvStdWindow>
+              typename InvStdWindow,
+              typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
+                                   const SXWindow& sx_window_,
                                    const GammaWindow& gamma_window_,
                                    const BetaWindow& beta_window_,
-                                   YWindow& y_window,
+                                   YWindow& y_window_,
+                                   const SYWindow& sy_window_,
                                    MeanWindow& mean_window,
                                    InvStdWindow& inv_std_window,
                                    ComputeDataType epsilon,
                                    ck_tile::index_t row_size,
-                                   void* smem) const
+                                   void* smem,
+                                   Epilogue) const
     {
         const auto x_window =
             make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
@@ -67,8 +79,14 @@ struct Layernorm2dFwdPipelineOnePass
             gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         const auto beta_window = make_tile_window(
             beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
+        const auto sx_window =
+            make_tile_window(sx_window_, Policy::template MakeXBlockTileDistribution<Problem>());
+        auto sy_window =
+            make_tile_window(sy_window_, Policy::template MakeXBlockTileDistribution<Problem>());

-        const auto x = load_tile(x_window);
+        auto x = load_tile(x_window);
+        auto sx = load_tile(sx_window);
         int cur_count = 0;
         int max_count =
             block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
@@ -81,6 +99,17 @@ struct Layernorm2dFwdPipelineOnePass
         const auto gamma = load_tile(gamma_window);
         const auto beta = load_tile(beta_window);

+        if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
+                     kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
+        {
+            sweep_tile(sx, [&](auto idx) {
+                // compute x = sx + x
+                x(idx) = type_convert<SYDataType>(sx(idx)) + type_convert<SYDataType>(x(idx));
+            });
+            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
+                store_tile(sy_window, x);
+        }
+
         // compute welford each-thread->cross-lane->cross-warp
         auto [mean, var] = block_welford(x, cur_count, max_count);
         block_welford_sync(mean, var, cur_count);
@@ -100,8 +129,8 @@ struct Layernorm2dFwdPipelineOnePass
             store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));

         // layernorm computation
-        auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
-        sweep_tile(y, [&, mean_ = mean](auto idx) {
+        auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
+        sweep_tile(ln, [&, mean_ = mean](auto idx) {
             constexpr auto i_idx = make_tuple(idx[number<0>{}]);
             constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -109,11 +138,12 @@ struct Layernorm2dFwdPipelineOnePass
             const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);

             const auto x_ = type_convert<ComputeDataType>(x[idx]);
-            auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
+            auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;

-            y(idx) = type_convert<YDataType>(y_);
+            ln(idx) = ln_;
         });

-        store_tile(y_window, y);
+        Epilogue{}(y_window_, ln);
     }
 };
 } // namespace ck_tile
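Both pipelines obtain mean and variance via a block-parallel Welford reduction (block_welford, then block_welford_sync across lanes and warps). For reference, a scalar sketch of the underlying update rule (the textbook algorithm, not ck_tile code):

#include <cstddef>

// One-pass Welford accumulator: numerically stabler than accumulating
// sum and sum-of-squares separately.
struct WelfordAcc
{
    double mean = 0.0;
    double m2   = 0.0; // sum of squared deviations from the running mean
    std::size_t count = 0;

    void update(double x)
    {
        ++count;
        const double delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean); // note: uses the *updated* mean
    }

    // population variance, matching what layernorm needs
    double variance() const { return count ? m2 / double(count) : 0.0; }
};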
@@ -15,9 +15,7 @@ template <typename XDataType_,
           typename MeanDataType_,
           typename InvStdDataType_,
           typename BlockShape_,
-          bool kPadN_,
-          bool kSaveMeanInvStd_,
-          bool kTwoPass_>
+          typename Traits_>
 struct Layernorm2dFwdPipelineProblem
 {
     using XDataType = remove_cvref_t<XDataType_>;
@@ -32,9 +30,7 @@ struct Layernorm2dFwdPipelineProblem
     static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
     static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;

-    static constexpr bool kPadN = kPadN_;
-    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
-    static constexpr bool kTwoPass = kTwoPass_;
+    using Traits = remove_cvref_t<Traits_>;
 };
 } // namespace ck_tile
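Folding the former kPadN_/kSaveMeanInvStd_/kTwoPass_ booleans into a single Traits_ type keeps the problem's template parameter list stable as new compile-time options (the fused-add and fused-sweep enums) are added; consumers now read Problem::Traits::kPadN and friends, as the pipeline changes in this commit show.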
@@ -24,14 +24,19 @@ struct Layernorm2dFwdPipelineTwoPass
     using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
     using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
+    using SXDataType = XDataType;
+    using SYDataType = XDataType;

     static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
     static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
-    static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
-    static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
+    static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
+    static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;

     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
-    static constexpr bool kPadN = Problem::kPadN;
+    static constexpr bool kPadN = Problem::Traits::kPadN;
+    static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
+    static constexpr auto kFusedSweep = Problem::Traits::kFusedSweep;

     static constexpr const char* name = []() {
         if constexpr(kNeedCrossWarpSync)
@@ -46,20 +51,26 @@ struct Layernorm2dFwdPipelineTwoPass
     }

     template <typename XWindow,
+              typename SXWindow,
               typename GammaWindow,
               typename BetaWindow,
               typename YWindow,
+              typename SYWindow,
               typename MeanWindow,
-              typename InvStdWindow>
+              typename InvStdWindow,
+              typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
+                                   const SXWindow& sx_window_,
                                    const GammaWindow& gamma_window_,
                                    const BetaWindow& beta_window_,
                                    YWindow& y_window,
+                                   const SYWindow& sy_window_,
                                    MeanWindow& mean_window,
                                    InvStdWindow& inv_std_window,
                                    ComputeDataType epsilon,
                                    ck_tile::index_t row_size,
-                                   void* smem) const
+                                   void* smem,
+                                   Epilogue) const
     {
         auto x_window =
             make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
@@ -67,6 +78,10 @@ struct Layernorm2dFwdPipelineTwoPass
             gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         auto beta_window = make_tile_window(
             beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
+        auto sx_window =
+            make_tile_window(sx_window_, Policy::template MakeXBlockTileDistribution<Problem>());
+        auto sy_window =
+            make_tile_window(sy_window_, Policy::template MakeXBlockTileDistribution<Problem>());

         // Problem::BlockShape
         static constexpr index_t Block_N = Problem::BlockShape::Block_N;
@@ -93,9 +108,25 @@ struct Layernorm2dFwdPipelineTwoPass
         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
-            const auto x = load_tile(x_window);
-            block_welford(x, mean, var, cur_count, max_count);
+            auto x = load_tile(x_window);
+            auto sx = load_tile(sx_window);
             move_tile_window(x_window, {0, Block_N});
+            move_tile_window(sx_window, {0, Block_N});
+
+            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
+                         kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
+            {
+                sweep_tile(sx, [&](auto idx) {
+                    // compute x = sx + x
+                    x(idx) = type_convert<SYDataType>(sx(idx)) + type_convert<SYDataType>(x(idx));
+                });
+                if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
+                {
+                    store_tile(sy_window, x);
+                    move_tile_window(sy_window, {0, Block_N});
+                }
+            }
+            block_welford(x, mean, var, cur_count, max_count);
         }

         block_welford_sync(mean, var, cur_count);
@@ -121,6 +152,7 @@ struct Layernorm2dFwdPipelineTwoPass
         // x_window.foo();
         // gamma_window.foo();
         move_tile_window(x_window, {0, -Block_N});
+        move_tile_window(sx_window, {0, -Block_N});
         move_tile_window(gamma_window, {stride_to_right_most_window});
         move_tile_window(beta_window, {stride_to_right_most_window});
         move_tile_window(y_window, {0, stride_to_right_most_window});
@@ -128,14 +160,23 @@ struct Layernorm2dFwdPipelineTwoPass
         // layernorm computation
         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
-            const auto x = load_tile(x_window);
+            auto x = load_tile(x_window);
+            auto sx = load_tile(sx_window);
+
+            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
+                         kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
+            {
+                sweep_tile(sx, [&](auto idx) {
+                    // compute x = sx + x
+                    x(idx) = type_convert<SYDataType>(sx(idx)) + type_convert<SYDataType>(x(idx));
+                });
+            }

             // load gamma/beta (TODO: support no gamma/beta?)
             const auto gamma = load_tile(gamma_window);
             const auto beta = load_tile(beta_window);

-            auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
-            sweep_tile(y, [&, mean_ = mean](auto idx) {
+            auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
+            sweep_tile(ln, [&, mean_ = mean](auto idx) {
                 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
                 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -143,14 +184,15 @@ struct Layernorm2dFwdPipelineTwoPass
                 const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);

                 const auto x_ = type_convert<ComputeDataType>(x[idx]);
-                auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
+                auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;

-                y(idx) = type_convert<YDataType>(y_);
+                ln(idx) = ln_;
             });

-            store_tile(y_window, y);
+            Epilogue{}(y_window, ln);
             move_tile_window(x_window, {0, -Block_N});
+            move_tile_window(sx_window, {0, -Block_N});
             move_tile_window(gamma_window, {-Block_N});
             move_tile_window(beta_window, {-Block_N});
             move_tile_window(y_window, {0, -Block_N});
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

enum class Layernorm2dFusedAddEnum
{
    NO_ADD = 0,
    // fused add before layernorm (prenorm), and store the sum to global memory
    PRE_ADD_STORE = 1,
    PRE_NORM_ADD  = PRE_ADD_STORE,
    // fused add before layernorm (postnorm), without storing the sum
    PRE_ADD       = 2,
    POST_NORM_ADD = PRE_ADD,
};

// clang-format off
template<Layernorm2dFusedAddEnum E> struct Layernorm2dFusedAddEnumName;
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
// clang-format on

enum class Layernorm2dFusedSweepEnum
{
    NO_SWEEP      = 0,
    RENORM        = 1,
    DYNAMIC_QUANT = 2,
};

// clang-format off
template<Layernorm2dFusedSweepEnum E> struct Layernorm2dFusedSweepEnumName;
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::RENORM> { static constexpr const char * name = "renorm"; };
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dequant"; };
// clang-format on

template <bool kPadN_,
          bool kSaveMeanInvStd_,
          bool kTwoPass_,
          Layernorm2dFusedAddEnum kFusedAdd_,
          Layernorm2dFusedSweepEnum kFusedSweep_>
struct Layernorm2dFwdTraits
{
    static constexpr bool kPadN           = kPadN_;
    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
    static constexpr bool kTwoPass        = kTwoPass_;

    static constexpr Layernorm2dFusedAddEnum   kFusedAdd   = kFusedAdd_;
    static constexpr Layernorm2dFusedSweepEnum kFusedSweep = kFusedSweep_;
};

} // namespace ck_tile
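A hypothetical instantiation of the new traits bundle (illustrative only; the generated instance files choose these values per tile shape):

// one-pass, padded-N, prenorm configuration
using PrenormTraits = ck_tile::Layernorm2dFwdTraits<
    /*kPadN_=*/true,
    /*kSaveMeanInvStd_=*/false,
    /*kTwoPass_=*/false,
    ck_tile::Layernorm2dFusedAddEnum::PRE_ADD_STORE,
    ck_tile::Layernorm2dFusedSweepEnum::NO_SWEEP>;

// PRE_NORM_ADD is an alias of PRE_ADD_STORE, so both spellings agree:
static_assert(PrenormTraits::kFusedAdd == ck_tile::Layernorm2dFusedAddEnum::PRE_NORM_ADD, "");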