Commit 4b59b5c9 authored by carlushuang

add prenorm/postnorm support, refactor using generate.py

parent 4d5248e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
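// rm/rn = Repeat_M/N, tm/tn = ThreadPerBlock_M/N, vn = Vector_N,
// pd = kPadN, mv = kSaveMeanInvStd, 2p = kTwoPass
// (same order as trait_'s template parameters in the common header)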
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
#include <iostream>
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
using trait_ = layernorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveMeanInvStd_,
kTwoPass_>;
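// e.g. trait_<fp16_t, 1, 2, 4, 64, 4, ...> describes a 4 x 512 block tile:
// Block_M = Repeat_M * ThreadPerBlock_M, Block_N = Repeat_N * ThreadPerBlock_N * Vector_N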
template <typename Traits_>
float layernorm2d_fwd_(const S& s, A a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<DataType>::XDataType,
typename LayerNormTypeConfig<DataType>::GammaDataType,
typename LayerNormTypeConfig<DataType>::BetaDataType,
typename LayerNormTypeConfig<DataType>::ComputeDataType,
typename LayerNormTypeConfig<DataType>::YDataType,
typename LayerNormTypeConfig<DataType>::MeanDataType,
typename LayerNormTypeConfig<DataType>::InvStdDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kSaveMeanInvStd,
Traits_::kTwoPass>;
using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Layernorm2dFwd<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
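// forwarded to make_kernel<blocks.x, kBlockPerCu> below, presumably as a
// min-blocks-per-CU occupancy hint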
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
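// The per-config instance files above explicitly instantiate this template for
// one trait_ combination each; the type-erased layernorm2d_fwd() entry point
// (generated by generate.py, per the commit message) presumably dispatches
// among these instantiations at runtime.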
#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include <cstring>
#include <algorithm>
// different thresholds for different dtypes
template <typename DataType>
@@ -29,7 +30,13 @@ auto create_args(int argc, char* argv[])
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
.insert("prec_i", "fp16", "input precision")
.insert("prec_o", "auto", "output precision, set auto will be the same as input")
.insert(
"fadd",
"0",
"fused-add, 0:no fused add, 1:fused-prenorm(preadd+store), 2:fused-postnorm(preadd)")
.insert("fsweep", "0", "fused-sweep")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
@@ -37,7 +44,7 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <typename DataType, bool SaveMeanVar>
template <typename InDataType, typename OutDataType, bool SaveMeanVar>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
@@ -45,21 +52,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
float epsilon = arg_parser.get_float("e");
std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
float epsilon = arg_parser.get_float("e");
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
if(prec_o == "auto")
{
prec_o = prec_i;
}
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int fused_add = arg_parser.get_int("fadd");
int fused_sweep = arg_parser.get_int("fsweep");
assert(stride >= n);
using TypeConfig = LayerNormTypeConfig<DataType>;
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType>;
using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType;
using SXDataType = XDataType;
using SYDataType = YDataType;
using MeanDataType =
std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;
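// with SaveMeanVar == false, null_type selects kernel variants that skip the
// mean/invstd stores entirely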
@@ -73,6 +89,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<BetaDataType> beta_host({n});
ck_tile::HostTensor<SXDataType> sx_host({m, n}, {stride, 1});
ck_tile::HostTensor<SYDataType> sy_host({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
@@ -88,19 +107,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem sx_buf(sx_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sy_buf(sy_host.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data());
sx_buf.ToDevice(sx_host.data());
std::cout << "[" << data_type << "]"
std::cout << "[" << prec_i << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
layernorm2d_fwd_traits traits{data_type, SaveMeanVar};
layernorm2d_fwd_traits traits{prec_i, prec_o, SaveMeanVar, fused_add, fused_sweep};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? sx_buf.GetDeviceBuffer() : nullptr,
gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(),
fused_add == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
nullptr,
nullptr,
epsilon,
@@ -111,6 +136,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = layernorm2d_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
@@ -122,6 +153,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
// reference
if(fused_add != 0)
{
// fused pre_add/pre_add_store
// TODO we accumulate directly to x_host for simplicity here...
std::transform(x_host.mData.cbegin(),
x_host.mData.cend(),
sx_host.mData.cbegin(),
x_host.mData.begin(),
std::plus<XDataType>{});
}
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
BetaDataType,
@@ -133,11 +175,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_buf.FromDevice(y_host_dev.data());
auto [rtol, atol] = get_elimit<DataType>();
ck_tile::HostTensor<SYDataType> sy_host_dev({m, n}, {stride, 1});
if(fused_add == 1)
{
sy_buf.FromDevice(sy_host_dev.data());
}
auto [rtol, atol] = get_elimit<InDataType>();
if(stride == n)
{
pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
if(fused_add == 1)
{
pass &= ck_tile::check_err(
sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol);
}
}
else
{
@@ -153,6 +206,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::string("] Error: Incorrect results!"),
rtol,
atol);
if(fused_add == 1)
{
std::vector<SYDataType> sy_host_dev_row(sy_host_dev.begin() + i_r * stride,
sy_host_dev.begin() + i_r * stride + n);
std::vector<SYDataType> sy_host_ref_row(x_host.begin() + i_r * stride,
x_host.begin() + i_r * stride + n);
pass &= ck_tile::check_err(sy_host_dev_row,
sy_host_ref_row,
std::string("ADD[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
@@ -168,23 +234,28 @@ int main(int argc, char* argv[])
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
int save_mv = arg_parser.get_int("save_mv");
if(data_type == "fp16" && save_mv)
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
if(prec_o == "auto")
{
prec_o = prec_i;
}
int save_mv = arg_parser.get_int("save_mv");
if(prec_i == "fp16" && prec_o == "fp16" && save_mv)
{
return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
return run<ck_tile::half_t, ck_tile::half_t, true>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp16" && !save_mv)
else if(prec_i == "fp16" && prec_o == "fp16" && !save_mv)
{
return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
return run<ck_tile::half_t, ck_tile::half_t, false>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16" && save_mv)
else if(prec_i == "bf16" && prec_o == "bf16" && save_mv)
{
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
return run<ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16" && !save_mv)
else if(prec_i == "bf16" && prec_o == "bf16" && !save_mv)
{
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
return run<ck_tile::bf16_t, ck_tile::bf16_t, false>(arg_parser) ? 0 : -2;
}
return -3;
@@ -8,14 +8,14 @@
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>
template <typename DataType>
template <typename InType, typename OutType>
struct LayerNormTypeConfig;
template <>
struct LayerNormTypeConfig<ck_tile::half_t>
template <typename OutType>
struct LayerNormTypeConfig<ck_tile::half_t, OutType>
{
using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t;
using YDataType = OutType;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t;
@@ -23,11 +23,11 @@ struct LayerNormTypeConfig<ck_tile::half_t>
using ComputeDataType = float;
};
template <>
struct LayerNormTypeConfig<ck_tile::bf16_t>
template <typename OutType>
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType>
{
using XDataType = ck_tile::bf16_t;
using YDataType = ck_tile::bf16_t;
using YDataType = OutType;
using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t;
@@ -40,82 +40,17 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
{
};
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
struct layernorm2d_fwd_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Layernorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
};
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
// This is the public API; it will be generated by the script
struct layernorm2d_fwd_traits
{
std::string data_type;
std::string prec_i;
std::string prec_o;
bool save_mean_var;
int fused_add;   // 0:no-add, 1:pre-add+store (prenorm), 2:pre-add (postnorm)
int fused_sweep; // 0:no-sweep, 1:renorm, 2:dynamic-quant
};
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
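For reference, a minimal host-side sketch of driving this entry point (a hypothetical setup: the buffer names mirror the example in main(), and the trailing epsilon/m/n/stride fields are assumed from Layernorm2dFwdHostArgs rather than confirmed by this diff):

layernorm2d_fwd_traits t{"fp16", "fp16", /*save_mean_var=*/false,
                         /*fused_add=*/1, /*fused_sweep=*/0};
layernorm2d_fwd_args a{x_buf.GetDeviceBuffer(),     // p_x: input [m, n]
                       sx_buf.GetDeviceBuffer(),    // p_sx: shortcut input
                       gamma_buf.GetDeviceBuffer(), // p_gamma
                       beta_buf.GetDeviceBuffer(),  // p_beta
                       y_buf.GetDeviceBuffer(),     // p_y: output
                       sy_buf.GetDeviceBuffer(),    // p_sy: stores x + sx
                       nullptr,                     // p_mean (save_mv only)
                       nullptr,                     // p_invStd (save_mv only)
                       1e-5f, m, n, stride};        // assumed field order
float avg_ms = layernorm2d_fwd(t, a, ck_tile::stream_config{nullptr, true});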
@@ -2,37 +2,37 @@
# run from top of ck folder
EXE=build/bin/tile_example_layernorm2d_fwd
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
\ No newline at end of file
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
\ No newline at end of file
@@ -3,29 +3,31 @@
EXE=./build/bin/tile_example_layernorm2d_fwd
for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13
$EXE -prec=$pr_i -m=17 -n=16
$EXE -prec=$pr_i -m=1 -n=100
$EXE -prec=$pr_i -m=4 -n=128
$EXE -prec=$pr_i -m=80 -n=127
$EXE -prec=$pr_i -m=22 -n=255 -stride=256
$EXE -prec=$pr_i -m=7 -n=599
$EXE -prec=$pr_i -m=19 -n=512
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000
$EXE -prec=$pr_i -m=11 -n=510
$EXE -prec=$pr_i -m=171 -n=676 -stride=818
$EXE -prec=$pr_i -m=91 -n=636
$EXE -prec=$pr_i -m=12 -n=768 -stride=800
$EXE -prec=$pr_i -m=100 -n=766 -stride=812
$EXE -prec=$pr_i -m=31 -n=1024
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
$EXE -prec=$pr_i -m=8 -n=1501
$EXE -prec=$pr_i -m=3 -n=1826
$EXE -prec=$pr_i -m=5 -n=2040
$EXE -prec=$pr_i -m=7 -n=2734
$EXE -prec=$pr_i -m=1 -n=3182
$EXE -prec=$pr_i -m=9 -n=4096
$EXE -prec=$pr_i -m=3 -n=8192
$EXE -prec=$pr_i -m=1 -n=10547
$EXE -prec=$pr_i -m=3 -n=17134
for fadd in "0" "1" "2"; do
$EXE -prec_i=$pr_i -fadd=$fadd -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd -m=80 -n=127
$EXE -prec_i=$pr_i -fadd=$fadd -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd -m=19 -n=512
$EXE -prec_i=$pr_i -fadd=$fadd -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd -m=11 -n=510
$EXE -prec_i=$pr_i -fadd=$fadd -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd -m=91 -n=636
$EXE -prec_i=$pr_i -fadd=$fadd -m=12 -n=768 -stride=800
$EXE -prec_i=$pr_i -fadd=$fadd -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd -m=31 -n=1024
$EXE -prec_i=$pr_i -fadd=$fadd -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=8192
$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=10547
$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=17134
done
done
@@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}
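// passthrough: attaching a tile distribution to a null window is a no-op, so
// pipelines can call make_tile_window(...) uniformly on windows that may be
// null (e.g. the shortcut sx/sy windows when fused-add is disabled)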
template <typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto make_tile_window(const null_tile_window<WindowLengths>& t,
const StaticTileDistribution&)
{
return t;
}
template <typename WindowLengths>
CK_TILE_DEVICE void
move_tile_window(null_tile_window<WindowLengths>&,
@@ -9,4 +9,5 @@
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
@@ -5,17 +5,20 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
namespace ck_tile {
// host side args
struct Layernorm2dFwdHostArgs
{
const void* p_x;
const void* p_x; // input
const void* p_sx; // shortcut input, set to nullptr if unused
const void* p_gamma;
const void* p_beta;
void* p_y;
void* p_y; // output
void* p_sy; // shortcut output, set to nullptr if unused
void* p_mean;
void* p_invStd;
@@ -27,10 +30,11 @@ struct Layernorm2dFwdHostArgs
};
// TODO: extract some types into a wrapper class
template <typename Pipeline_>
template <typename Pipeline_, typename Epilogue_>
struct Layernorm2dFwd
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
@@ -41,17 +45,23 @@ struct Layernorm2dFwd
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
// for simplicity, the shortcut input/output types are the same as X
using SXDataType = XDataType;
using SYDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>;
static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedSweep = Problem::Traits::kFusedSweep;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
@@ -62,11 +72,13 @@ struct Layernorm2dFwd
struct Kargs
{
const void* p_x;
const void* p_x; // input
const void* p_sx; // shortcut input, set to nullptr if unused
const void* p_gamma;
const void* p_beta;
void* p_y;
void* p_y; // output
void* p_sy; // shortcut output, set to nullptr if unused
void* p_mean;
void* p_invStd;
@@ -81,9 +93,11 @@ struct Layernorm2dFwd
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_x,
hargs.p_sx,
hargs.p_gamma,
hargs.p_beta,
hargs.p_y,
hargs.p_sy,
hargs.p_mean,
hargs.p_invStd,
hargs.epsilon,
@@ -113,17 +127,19 @@ struct Layernorm2dFwd
CK_TILE_HOST static std::string GetName()
{
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
using S_ = typename Problem::BlockShape;
auto surfix = [&] () {
std::string n;
if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
if (kFusedSweep != Layernorm2dFusedSweepEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedSweepEnumName<kFusedSweep>::name;
if (kPadN) n += "_pn";
if (kSaveMeanInvStd) n += "_mv";
if (kTwoPass) n += "_2p";
return n; }();
return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
@@ -153,6 +169,31 @@ struct Layernorm2dFwd
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto sx_window = [&]() {
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const SXDataType*>(kargs.p_sx),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
// NOTE: we don't do any padding for loading in this kernel; we assume
// the kernel checks the max count dynamically inside
const auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<false, false>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma),
@@ -194,6 +235,28 @@ struct Layernorm2dFwd
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
auto sy_window = [&]() {
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<SYDataType*>(kargs.p_sy),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
auto mean_window = [&]() {
if constexpr(kSaveMean)
{
@@ -235,14 +298,17 @@ struct Layernorm2dFwd
__shared__ char smem[GetSmemSize()];
Pipeline{}(x_window,
sx_window,
gamma_window,
beta_window,
y_window,
sy_window,
mean_window,
inv_std_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n,
smem);
smem,
Epilogue{});
}
};
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include <string>
#include <type_traits>
@@ -24,14 +25,19 @@ struct Layernorm2dFwdPipelineOnePass
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
using SXDataType = XDataType;
using SYDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedSweep = Problem::Traits::kFusedSweep;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
@@ -46,20 +52,26 @@ struct Layernorm2dFwdPipelineOnePass
}
template <typename XWindow,
typename SXWindow,
typename GammaWindow,
typename BetaWindow,
typename YWindow,
typename SYWindow,
typename MeanWindow,
typename InvStdWindow>
typename InvStdWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const SXWindow& sx_window_,
const GammaWindow& gamma_window_,
const BetaWindow& beta_window_,
YWindow& y_window,
YWindow& y_window_,
const SYWindow& sy_window_,
MeanWindow& mean_window,
InvStdWindow& inv_std_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
void* smem,
Epilogue) const
{
const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
@@ -67,8 +79,14 @@ struct Layernorm2dFwdPipelineOnePass
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto beta_window = make_tile_window(
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto sx_window =
make_tile_window(sx_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto sy_window =
make_tile_window(sy_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x = load_tile(x_window);
auto sx = load_tile(sx_window);
const auto x = load_tile(x_window);
int cur_count = 0;
int max_count =
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
@@ -81,6 +99,17 @@ struct Layernorm2dFwdPipelineOnePass
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(sx, [&](auto idx) {
// compute x = sx + x
x(idx) = type_convert<SYDataType>(sx(idx)) + type_convert<SYDataType>(x(idx));
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
store_tile(sy_window, x);
}
// compute welford each-thread->cross-lane->cross-warp
auto [mean, var] = block_welford(x, cur_count, max_count);
block_welford_sync(mean, var, cur_count);
@@ -100,8 +129,8 @@ struct Layernorm2dFwdPipelineOnePass
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
// layernorm computation
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
sweep_tile(y, [&, mean_ = mean](auto idx) {
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -109,11 +138,12 @@ struct Layernorm2dFwdPipelineOnePass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
y(idx) = type_convert<YDataType>(y_);
ln(idx) = ln_;
});
store_tile(y_window, y);
Epilogue{}(y_window_, ln);
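// the epilogue now owns the YDataType conversion and the final store, so
// fused-sweep variants (renorm/dynamic-quant) can presumably hook in here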
}
};
} // namespace ck_tile
@@ -15,9 +15,7 @@ template <typename XDataType_,
typename MeanDataType_,
typename InvStdDataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
typename Traits_>
struct Layernorm2dFwdPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
@@ -32,9 +30,7 @@ struct Layernorm2dFwdPipelineProblem
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
using Traits = remove_cvref_t<Traits_>;
};
} // namespace ck_tile
@@ -24,14 +24,19 @@ struct Layernorm2dFwdPipelineTwoPass
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
using SXDataType = XDataType;
using SYDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedSweep = Problem::Traits::kFusedSweep;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
@@ -46,20 +51,26 @@ struct Layernorm2dFwdPipelineTwoPass
}
template <typename XWindow,
typename SXWindow,
typename GammaWindow,
typename BetaWindow,
typename YWindow,
typename SYWindow,
typename MeanWindow,
typename InvStdWindow>
typename InvStdWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const SXWindow& sx_window_,
const GammaWindow& gamma_window_,
const BetaWindow& beta_window_,
YWindow& y_window,
const SYWindow& sy_window_,
MeanWindow& mean_window,
InvStdWindow& inv_std_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
void* smem,
Epilogue) const
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
@@ -67,6 +78,10 @@ struct Layernorm2dFwdPipelineTwoPass
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto beta_window = make_tile_window(
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto sx_window =
make_tile_window(sx_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto sy_window =
make_tile_window(sy_window_, Policy::template MakeXBlockTileDistribution<Problem>());
// Problem::BlockShape
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
@@ -93,9 +108,25 @@ struct Layernorm2dFwdPipelineTwoPass
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
block_welford(x, mean, var, cur_count, max_count);
auto x = load_tile(x_window);
auto sx = load_tile(sx_window);
move_tile_window(x_window, {0, Block_N});
move_tile_window(sx_window, {0, Block_N});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(sx, [&](auto idx) {
// compute x = sx + x
x(idx) = type_convert<SYDataType>(sx(idx)) + type_convert<SYDataType>(x(idx));
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(sy_window, x);
move_tile_window(sy_window, {0, Block_N});
}
}
block_welford(x, mean, var, cur_count, max_count);
}
block_welford_sync(mean, var, cur_count);
@@ -121,6 +152,7 @@ struct Layernorm2dFwdPipelineTwoPass
// x_window.foo();
// gamma_window.foo();
move_tile_window(x_window, {0, -Block_N});
move_tile_window(sx_window, {0, -Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window});
@@ -128,14 +160,23 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
auto x = load_tile(x_window);
auto sx = load_tile(sx_window);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(sx, [&](auto idx) {
// compute x = sx + x
x(idx) = type_convert<SYDataType>(sx(idx)) + type_convert<SYDataType>(x(idx));
});
}
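// the add is recomputed here because the second pass re-loads x from global
// rather than caching the sums produced in the first (welford) pass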
// load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(y, [&, mean_ = mean](auto idx) {
sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -143,14 +184,15 @@ struct Layernorm2dFwdPipelineTwoPass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
y(idx) = type_convert<YDataType>(y_);
ln(idx) = ln_;
});
store_tile(y_window, y);
Epilogue{}(y_window, ln);
move_tile_window(x_window, {0, -Block_N});
move_tile_window(sx_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N});
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
enum class Layernorm2dFusedAddEnum
{
NO_ADD = 0,
// fused add before layernorm (prenorm); the sum is also stored to global memory
PRE_ADD_STORE = 1,
PRE_NORM_ADD = PRE_ADD_STORE,
// fused add before layernorm (postnorm); the sum is not stored
PRE_ADD = 2,
POST_NORM_ADD = PRE_ADD,
};
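// In effect (x: input row, sx: shortcut input):
//   PRE_ADD_STORE (prenorm) : sy = x + sx; y = layernorm(sy)  // sum written out
//   PRE_ADD (postnorm)      : y = layernorm(x + sx)           // sum discarded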
// clang-format off
template<Layernorm2dFusedAddEnum E> struct Layernorm2dFusedAddEnumName;
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
// clang-format on
enum class Layernorm2dFusedSweepEnum
{
NO_SWEEP = 0,
RENORM = 1,
DYNAMIC_QUANT = 2,
};
// clang-format off
template<Layernorm2dFusedSweepEnum E> struct Layernorm2dFusedSweepEnumName;
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::RENORM> { static constexpr const char * name = "renorm"; };
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dequant"; };
// clang-format on
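// these short names are appended as suffixes to the kernel name built in
// Layernorm2dFwd::GetName() (e.g. "_pras" for PRE_ADD_STORE), so the active
// fusion shows up when kernel names are printed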
template <bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_,
Layernorm2dFusedAddEnum kFusedAdd_,
Layernorm2dFusedSweepEnum kFusedSweep_>
struct Layernorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Layernorm2dFusedSweepEnum kFusedSweep = kFusedSweep_;
};
} // namespace ck_tile