Unverified Commit c3a4800c authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

[CK_TILE] layernorm support fused-quant/fused-add (#1604)

* add prenorm/postnorm support, refactor using generate.py

* update README

* update README

* fix format

* update some description and fix format

* update format

* format

* use non-raw for loading

* format and update n4096

* dynamic-quant ready

* update readme

* support fused dynamic-quant

* update fused-quant, with smooth

* update README

* update args

* update some based on comment
parent 9a8a5213
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
#include <iostream>
#pragma once
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
using trait_ = layernorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveMeanInvStd_,
kTwoPass_>;
template <typename Traits_>
float layernorm2d_fwd_(const S& s, A a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<DataType>::XDataType,
typename LayerNormTypeConfig<DataType>::GammaDataType,
typename LayerNormTypeConfig<DataType>::BetaDataType,
typename LayerNormTypeConfig<DataType>::ComputeDataType,
typename LayerNormTypeConfig<DataType>::YDataType,
typename LayerNormTypeConfig<DataType>::MeanDataType,
typename LayerNormTypeConfig<DataType>::InvStdDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kSaveMeanInvStd,
Traits_::kTwoPass>;
using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Layernorm2dFwd<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp" #include "layernorm2d_fwd.hpp"
#include <algorithm>
#include <cstring> #include <cstring>
// different threshold for different dtype // different threshold for different dtype
...@@ -29,7 +30,16 @@ auto create_args(int argc, char* argv[]) ...@@ -29,7 +30,16 @@ auto create_args(int argc, char* argv[])
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case") .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision") .insert("prec_i", "fp16", "input precision")
.insert("prec_o", "auto", "output precision, set auto will be the same as input")
.insert("prec_sx",
"auto",
"output quant scale type, set auto will use fp32. used when fquant=1")
.insert("prec_sy",
"auto",
"output quant scale type, set auto will use fp32. used when fquant=1 or 2")
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("warmup", "5", "cold iter") .insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter"); .insert("repeat", "20", "hot iter");
...@@ -37,7 +47,11 @@ auto create_args(int argc, char* argv[]) ...@@ -37,7 +47,11 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
} }
template <typename DataType, bool SaveMeanVar> template <typename InDataType,
typename OutDataType,
typename XScaleDataType,
typename YScaleDataType,
bool SaveMeanVar>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t m = arg_parser.get_int("m");
...@@ -45,21 +59,46 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -45,21 +59,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0) if(stride < 0)
stride = n; stride = n;
float epsilon = arg_parser.get_float("e"); float epsilon = arg_parser.get_float("e");
std::string data_type = arg_parser.get_str("prec"); std::string prec_i = arg_parser.get_str("prec_i");
int kname = arg_parser.get_int("kname"); std::string prec_o = arg_parser.get_str("prec_o");
int do_validation = arg_parser.get_int("v"); std::string prec_sx = arg_parser.get_str("prec_sx");
int warmup = arg_parser.get_int("warmup"); std::string prec_sy = arg_parser.get_str("prec_sy");
int repeat = arg_parser.get_int("repeat"); if(prec_o == "auto")
{
prec_o = prec_i;
}
if(prec_sx == "auto")
{
prec_sx = "fp32";
}
if(prec_sy == "auto")
{
prec_sy = "fp32";
}
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
if(fused_quant == 1 && prec_o != "int8")
{
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl;
return false;
}
assert(stride >= n); assert(stride >= n);
using TypeConfig = LayerNormTypeConfig<DataType>; using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>;
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType; using YDataType = typename TypeConfig::YDataType;
using GammaDataType = typename TypeConfig::GammaDataType; using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType; using BetaDataType = typename TypeConfig::BetaDataType;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
using MeanDataType = using MeanDataType =
std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>; std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;
...@@ -73,36 +112,72 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -73,36 +112,72 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<GammaDataType> gamma_host({n}); ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<BetaDataType> beta_host({n}); ck_tile::HostTensor<BetaDataType> beta_host({n});
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<MeanDataType> mean_host_ref({m}); ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m}); ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
ck_tile::HostTensor<XScaleDataType> x_scale_host({n});
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host); ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data()); gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data()); beta_buf.ToDevice(beta_host.data());
x_residual_buf.ToDevice(x_residual_host.data());
x_scale_buf.ToDevice(x_scale_host.data());
std::cout << "[" << data_type << "]" auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_o)
{
base_str += "|" + prec_o;
}
if(fused_quant == 1)
{
base_str += std::string("(") + prec_sy + ")";
}
return base_str;
}();
std::cout << "[" << prec_str << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
layernorm2d_fwd_traits traits{data_type, SaveMeanVar}; layernorm2d_fwd_traits traits{
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
gamma_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(), y_buf.GetDeviceBuffer(),
nullptr, fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
nullptr, fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
nullptr, // p_mean, unsupported yet
nullptr, // p_invStd, unsupported yet
epsilon, epsilon,
m, m,
n, n,
...@@ -111,6 +186,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -111,6 +186,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = layernorm2d_fwd( float ave_time = layernorm2d_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
...@@ -122,6 +203,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -122,6 +203,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
// reference // reference
if(fused_add != 0)
{
// fused pre_add/pre_add_store
// TODO we accumulate directly to x_host for simplcity here...
std::transform(x_host.mData.cbegin(),
x_host.mData.cend(),
x_residual_host.mData.cbegin(),
x_host.mData.begin(),
std::plus<XDataType>{});
}
ck_tile::reference_layernorm2d_fwd<XDataType, ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
...@@ -131,13 +223,80 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -131,13 +223,80 @@ bool run(const ck_tile::ArgParser& arg_parser)
InvStdDataType>( InvStdDataType>(
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
if(fused_quant != 0)
{
auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
int N_ = acc_.mDesc.get_lengths()[1];
if(fused_quant == 1)
{
for(int n_ = 0; n_ < N_; n_++)
{
// input smooth outlier
acc_(m_, n_) =
acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_));
}
}
ComputeDataType absmax = static_cast<ComputeDataType>(0);
for(int n_ = 0; n_ < N_; n_++)
{
const auto a = ck_tile::abs(acc_(m_, n_));
absmax = a > absmax ? a : absmax;
}
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++)
{
o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
}
};
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(x_host,
gamma_host,
beta_host,
y_host_ref,
mean_host_ref,
invStd_host_ref,
epsilon,
dquant_functor);
}
else
{
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
}
y_buf.FromDevice(y_host_dev.data()); y_buf.FromDevice(y_host_dev.data());
auto [rtol, atol] = get_elimit<DataType>(); ck_tile::HostTensor<YResidualDataType> sy_host_dev({m, n}, {stride, 1});
if(fused_add == 1)
{
y_residual_buf.FromDevice(sy_host_dev.data());
}
auto [rtol, atol] = get_elimit<InDataType>();
if(stride == n) if(stride == n)
{ {
pass = ck_tile::check_err( pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
if(fused_add == 1)
{
pass &= ck_tile::check_err(
sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol);
}
} }
else else
{ {
...@@ -153,8 +312,30 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -153,8 +312,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::string("] Error: Incorrect results!"), std::string("] Error: Incorrect results!"),
rtol, rtol,
atol); atol);
if(fused_add == 1)
{
std::vector<YResidualDataType> sy_host_dev_row(
sy_host_dev.begin() + i_r * stride, sy_host_dev.begin() + i_r * stride + n);
std::vector<YResidualDataType> sy_host_ref_row(
x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n);
pass &= ck_tile::check_err(sy_host_dev_row,
sy_host_ref_row,
std::string("ADD[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
} }
} }
if(fused_quant == 1)
{
y_scale_buf.FromDevice(y_scale_host_dev.data());
pass &= ck_tile::check_err(y_scale_host_dev,
y_scale_host_ref,
std::string("SCALE Error: Incorrect results!"),
rtol,
atol);
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
} }
...@@ -168,23 +349,56 @@ int main(int argc, char* argv[]) ...@@ -168,23 +349,56 @@ int main(int argc, char* argv[])
if(!result) if(!result)
return -1; return -1;
const std::string data_type = arg_parser.get_str("prec"); std::string prec_i = arg_parser.get_str("prec_i");
int save_mv = arg_parser.get_int("save_mv"); std::string prec_o = arg_parser.get_str("prec_o");
if(data_type == "fp16" && save_mv) std::string prec_sx = arg_parser.get_str("prec_sx");
std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto")
{
prec_o = prec_i;
}
if(prec_sx == "auto")
{ {
return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2; prec_sx = "fp32";
} }
else if(data_type == "fp16" && !save_mv) if(prec_sy == "auto")
{ {
return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2; prec_sy = "fp32";
} }
else if(data_type == "bf16" && save_mv) int save_mv = arg_parser.get_int("save_mv");
// no dynamic quant case
if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv)
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" &&
save_mv)
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
// dynamic quant case, only in inference
else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" &&
!save_mv)
{ {
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2; return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
} }
else if(data_type == "bf16" && !save_mv) else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" &&
!save_mv)
{ {
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
} }
return -3; return -3;
......
...@@ -8,31 +8,35 @@ ...@@ -8,31 +8,35 @@
#include "ck_tile/ops/layernorm2d.hpp" #include "ck_tile/ops/layernorm2d.hpp"
#include <string> #include <string>
template <typename DataType> template <typename InType, typename OutType, typename XScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig; struct LayerNormTypeConfig;
template <> template <typename OutType, typename XScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::half_t> struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleDataType_>
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t; using YDataType = OutType;
using GammaDataType = ck_tile::half_t; using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t; using BetaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t; using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t; using InvStdDataType = ck_tile::half_t;
using ComputeDataType = float; using ComputeDataType = float;
using XScaleDataType = XScaleDataType_;
using YScaleDataType = YScaleDataType_;
}; };
template <> template <typename OutType, typename XScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::bf16_t> struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleDataType_>
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using YDataType = ck_tile::bf16_t; using YDataType = OutType;
using GammaDataType = ck_tile::bf16_t; using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t; using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t; using MeanDataType = ck_tile::bf16_t;
using InvStdDataType = ck_tile::bf16_t; using InvStdDataType = ck_tile::bf16_t;
using ComputeDataType = float; using ComputeDataType = float;
using XScaleDataType = XScaleDataType_;
using YScaleDataType = YScaleDataType_;
}; };
// runtime args // runtime args
...@@ -40,82 +44,21 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs ...@@ -40,82 +44,21 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
{ {
}; };
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
struct layernorm2d_fwd_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Layernorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
};
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
// This is the public API, will be generated by script // This is the public API, will be generated by script
struct layernorm2d_fwd_traits struct layernorm2d_fwd_traits
{ {
std::string data_type; std::string prec_i; // input precision
bool save_mean_var; std::string prec_o; // output precision
// if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
// can set arbitrary(will skip check)
std::string prec_sx; // x-scale, used for [1*N] input smooth quant
std::string prec_sy; // y-scale, used for [M*1] output for next layer
bool save_mean_var; //
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
}; };
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
...@@ -2,37 +2,37 @@ ...@@ -2,37 +2,37 @@
# run from top of ck folder # run from top of ck folder
EXE=build/bin/tile_example_layernorm2d_fwd EXE=build/bin/tile_example_layernorm2d_fwd
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 $EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
\ No newline at end of file \ No newline at end of file
...@@ -2,30 +2,34 @@ ...@@ -2,30 +2,34 @@
# call from top of CK folder # call from top of CK folder
EXE=./build/bin/tile_example_layernorm2d_fwd EXE=./build/bin/tile_example_layernorm2d_fwd
for fquant in "" "-fquant=1 -prec_o=int8"; do
for pr_i in "fp16" "bf16" ; do for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13 for fadd in "0" "1"; do
$EXE -prec=$pr_i -m=17 -n=16 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
$EXE -prec=$pr_i -m=1 -n=100 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
$EXE -prec=$pr_i -m=4 -n=128 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
$EXE -prec=$pr_i -m=80 -n=127 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
$EXE -prec=$pr_i -m=22 -n=255 -stride=256 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
$EXE -prec=$pr_i -m=7 -n=599 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
$EXE -prec=$pr_i -m=19 -n=512 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
$EXE -prec=$pr_i -m=11 -n=510 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
$EXE -prec=$pr_i -m=171 -n=676 -stride=818 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
$EXE -prec=$pr_i -m=91 -n=636 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
$EXE -prec=$pr_i -m=12 -n=768 -stride=800 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
$EXE -prec=$pr_i -m=100 -n=766 -stride=812 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
$EXE -prec=$pr_i -m=31 -n=1024 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
$EXE -prec=$pr_i -m=8 -n=1501 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
$EXE -prec=$pr_i -m=3 -n=1826 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
$EXE -prec=$pr_i -m=5 -n=2040 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
$EXE -prec=$pr_i -m=7 -n=2734 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
$EXE -prec=$pr_i -m=1 -n=3182 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec=$pr_i -m=9 -n=4096 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec=$pr_i -m=3 -n=8192 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec=$pr_i -m=1 -n=10547 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
$EXE -prec=$pr_i -m=3 -n=17134 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done done
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/math.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#pragma once
namespace ck_tile {
// use int8_t directly for int8 arithemetic
// here one can use ck_tile::int8_t to access original int8_t
using int8_t = int8_t;
// limits
template <class T>
struct numeric;
template <>
struct numeric<int8_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr int8_t epsilon()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr int8_t round_error()
{
return 1; // not used
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr int8_t infinity()
{
return 1; // not used
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN()
{
return 1; // not used
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN()
{
return 1; // not used
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
};
#if 0
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<int8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
#endif
CK_TILE_HOST_DEVICE
constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
constexpr int8_t float_to_int8(const float& x) { return static_cast<int8_t>(x); }
} // namespace ck_tile
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/int8.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float) ...@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float) CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float) CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
#undef CK_TILE_TYPE_CONVERT #undef CK_TILE_TYPE_CONVERT
#endif #endif
......
...@@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, ...@@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths}; return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
} }
template <typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto make_tile_window(const null_tile_window<WindowLengths>& t,
const StaticTileDistribution&)
{
return t;
}
template <typename WindowLengths> template <typename WindowLengths>
CK_TILE_DEVICE void CK_TILE_DEVICE void
move_tile_window(null_tile_window<WindowLengths>&, move_tile_window(null_tile_window<WindowLengths>&,
......
...@@ -8,20 +8,44 @@ ...@@ -8,20 +8,44 @@
namespace ck_tile { namespace ck_tile {
// Note: for simplicity, each functor only care about single M
struct reference_layernorm2d_default_epilogue
{
template <typename OutDataType, typename AccDataType>
void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
{
const int N = acc.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
}
}
template <typename OutDataType, typename AccDataType>
auto operator()(int m, const HostTensor<AccDataType>& acc)
{
HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
operator()(m, o, acc);
return o;
}
};
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename MeanDataType, typename MeanDataType,
typename InvStdDataType> typename InvStdDataType,
typename Epilogue = reference_layernorm2d_default_epilogue>
void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n, void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n, const HostTensor<GammaDataType>& gamma_n,
const HostTensor<BetaDataType>& beta_n, const HostTensor<BetaDataType>& beta_n,
HostTensor<YDataType>& y_m_n, HostTensor<YDataType>& y_m_n,
HostTensor<MeanDataType>& mean_m, HostTensor<MeanDataType>& mean_m,
HostTensor<InvStdDataType>& invStd_m, HostTensor<InvStdDataType>& invStd_m,
ComputeDataType epsilon) ComputeDataType epsilon,
Epilogue epilogue_functor = {})
{ {
auto layernorm2d_fwd_func = [&](auto m) { auto layernorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1]; const int N = x_m_n.mDesc.get_lengths()[1];
...@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n, ...@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>) if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor); invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);
HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n)); ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n)); ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
ComputeDataType beta = ck_tile::type_convert<ComputeDataType>(beta_n(n)); ComputeDataType beta = ck_tile::type_convert<ComputeDataType>(beta_n(n));
auto y = (x - mean) * divisor; auto a_ = (x - mean) * divisor;
y = y * gamma + beta; a_ = a_ * gamma + beta;
y_m_n(m, n) = ck_tile::type_convert<YDataType>(y); acc(m, n) = a_;
} }
epilogue_functor(m, y_m_n, acc);
}; };
make_ParallelTensorFunctor(layernorm2d_fwd_func, make_ParallelTensorFunctor(layernorm2d_fwd_func,
......
...@@ -9,4 +9,5 @@ ...@@ -9,4 +9,5 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -3,4 +3,5 @@ ...@@ -3,4 +3,5 @@
#pragma once #pragma once
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp"
namespace ck_tile { namespace ck_tile {
/* /*
// clang-format off // clang-format off
...@@ -42,7 +41,7 @@ template <typename BlockTile_, // block size, seq<M, N> ...@@ -42,7 +41,7 @@ template <typename BlockTile_, // block size, seq<M, N>
typename Vector_, // contiguous pixels(vector size) along seq<M, N> typename Vector_, // contiguous pixels(vector size) along seq<M, N>
index_t BlockSize_ = index_t BlockSize_ =
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct Layernorm2dShape struct Generic2dBlockShape
{ {
// block size // block size
static constexpr index_t Block_M = BlockTile_::at(number<0>{}); static constexpr index_t Block_M = BlockTile_::at(number<0>{});
......
...@@ -4,4 +4,5 @@ ...@@ -4,4 +4,5 @@
#pragma once #pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment