Commit b894487c authored by rocking's avatar rocking
Browse files

Support bf16

parent 4e14a894
...@@ -2,6 +2,23 @@ ...@@ -2,6 +2,23 @@
#include "layernorm2d_fwd.hpp" #include "layernorm2d_fwd.hpp"
#include <cstring> #include <cstring>
// Validation thresholds (rtol, atol) selected per data type.
template <typename DataType>
auto get_elimit()
{
    // Default tolerance, used by fp16/fp32 and any other unlisted type.
    return ck_tile::make_tuple(1e-3, 1e-3);
}

template <>
auto get_elimit<ck_tile::bf16_t>()
{
    // bf16 carries far fewer mantissa bits, so accept a looser error bound.
    return ck_tile::make_tuple(1e-2, 1e-2);
}
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
...@@ -52,7 +69,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -52,7 +69,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<MeanDataType> mean_host_ref({M}); ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M}); ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host); ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
...@@ -105,14 +121,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -105,14 +121,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_buf.FromDevice(y_host_dev.data()); y_buf.FromDevice(y_host_dev.data());
pass = ck_tile::check_err(y_host_dev, y_host_ref); auto [rtol, atol] = get_elimit<DataType>();
pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
} }
std::cout << std::endl << std::flush;
std::cout << "pass = " << pass << std::endl;
return pass; return pass;
} }
...@@ -127,6 +142,10 @@ int main(int argc, char* argv[]) ...@@ -127,6 +142,10 @@ int main(int argc, char* argv[])
{ {
return run<ck_tile::half_t>(arg_parser) ? 0 : -2; return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
} }
if(data_type == "bf16")
{
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
}
return -3; return -3;
} }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
// Instantiate and launch the 2D layernorm forward kernel described by Traits_.
// The per-dtype element types come from LayerNormTypeConfig; the tiling shape
// and the padding / save-mean-invstd / two-pass switches come from Traits_.
// Returns whatever ck_tile::launch_kernel reports (elapsed time in float).
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a)
{
    using DataType = typename Traits_::DataType;
    using TypeCfg  = LayerNormTypeConfig<DataType>;

    using PipelineProblem =
        ck_tile::BlockLayernorm2dFwdProblem<typename TypeCfg::XDataType,
                                            typename TypeCfg::GammaDataType,
                                            typename TypeCfg::BetaDataType,
                                            typename TypeCfg::ComputeDataType,
                                            typename TypeCfg::YDataType,
                                            typename TypeCfg::MeanDataType,
                                            typename TypeCfg::InvStdDataType,
                                            typename Traits_::Shape,
                                            Traits_::kPadN,
                                            Traits_::kSaveMeanInvStd,
                                            Traits_::kTwoPass>;

    using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;

    // Grid scales with the number of rows M; block size is fixed by the kernel.
    // block must be constexpr because block.x is used as a template argument.
    const dim3 grid           = Kernel::GridSize(a.M);
    constexpr dim3 block      = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    return ck_tile::launch_kernel(s,
                                  ck_tile::make_kernel<block.x, kBlockPerCu>(Kernel{},
                                                                             grid,
                                                                             block,
                                                                             0, // presumably dynamic LDS bytes — confirm make_kernel contract
                                                                             a.p_x,
                                                                             a.p_gamma,
                                                                             a.p_beta,
                                                                             a.p_y,
                                                                             a.p_mean,
                                                                             a.p_invStd,
                                                                             a.epsilon,
                                                                             a.M,
                                                                             a.N));
}
// Shorthand trait for the bf16 instance family in this TU: the two bools after
// VectorAccessSize are fixed to false, false (presumably kPadN and
// kSaveMeanInvStd — confirm against layernorm2d_fwd_traits_'s declaration).
template <ck_tile::index_t NRepeat,
          ck_tile::index_t NThread,
          ck_tile::index_t VectorAccessSize,
          bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
                                  NRepeat,
                                  NThread,
                                  VectorAccessSize,
                                  false,
                                  false,
                                  kTwoPass>;
// Short aliases to keep the explicit instantiation lines readable.
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;
// Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
// (NOTE(review): comment says "8fp16" but this TU is bf16 — likely copied from the
// fp16 instance file; the disabled lines are the 8-element-vector variants.)
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
// 4-element vector instances: NRepeat grows with N; the final one is two-pass.
template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
// Instantiate and launch the 2D layernorm forward kernel described by Traits_.
// Element types are taken from LayerNormTypeConfig<DataType>; tiling shape and
// the kPadN / kSaveMeanInvStd / kTwoPass switches come from Traits_.
// Returns the value reported by ck_tile::launch_kernel (elapsed time, float).
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a)
{
    using DataType = typename Traits_::DataType;
    using PipelineProblem =
        ck_tile::BlockLayernorm2dFwdProblem<typename LayerNormTypeConfig<DataType>::XDataType,
                                            typename LayerNormTypeConfig<DataType>::GammaDataType,
                                            typename LayerNormTypeConfig<DataType>::BetaDataType,
                                            typename LayerNormTypeConfig<DataType>::ComputeDataType,
                                            typename LayerNormTypeConfig<DataType>::YDataType,
                                            typename LayerNormTypeConfig<DataType>::MeanDataType,
                                            typename LayerNormTypeConfig<DataType>::InvStdDataType,
                                            typename Traits_::Shape,
                                            Traits_::kPadN,
                                            Traits_::kSaveMeanInvStd,
                                            Traits_::kTwoPass>;
    using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;

    // Grid scales with row count M; block size is fixed by the kernel and must
    // be constexpr because blocks.x is used as a template argument below.
    const dim3 grids                       = Kernel::GridSize(a.M);
    constexpr dim3 blocks                  = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    return ck_tile::launch_kernel(s,
                                  ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{},
                                                                              grids,
                                                                              blocks,
                                                                              0, // presumably dynamic LDS bytes — confirm make_kernel contract
                                                                              a.p_x,
                                                                              a.p_gamma,
                                                                              a.p_beta,
                                                                              a.p_y,
                                                                              a.p_mean,
                                                                              a.p_invStd,
                                                                              a.epsilon,
                                                                              a.M,
                                                                              a.N));
}
template <ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
NRepeat,
NThread,
VectorAccessSize,
true,
false,
kTwoPass>;
using S = const ck_tile::stream_config;
using A = layernorm2d_fwd_args;
// Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 64, 2, true>>(const S&, A);
template float layernorm2d_fwd_<t<32, 64, 1, false>>(const S&, A);
template float layernorm2d_fwd_<t<32, 64, 1, true>>(const S&, A);
...@@ -24,14 +24,14 @@ struct LayerNormTypeConfig<ck_tile::half_t> ...@@ -24,14 +24,14 @@ struct LayerNormTypeConfig<ck_tile::half_t>
}; };
template <> template <>
struct LayerNormTypeConfig<float> struct LayerNormTypeConfig<ck_tile::bf16_t>
{ {
using XDataType = float; using XDataType = ck_tile::bf16_t;
using YDataType = float; using YDataType = ck_tile::bf16_t;
using GammaDataType = float; using GammaDataType = ck_tile::bf16_t;
using BetaDataType = float; using BetaDataType = ck_tile::bf16_t;
using MeanDataType = float; using MeanDataType = ck_tile::bf16_t;
using InvStdDataType = float; using InvStdDataType = ck_tile::bf16_t;
using ComputeDataType = float; using ComputeDataType = float;
}; };
......
...@@ -90,6 +90,78 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -90,6 +90,78 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, true>>(s, a); : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, true>>(s, a);
} }
} }
else if(t.data_type.compare("bf16") == 0)
{
if(a.N % 4 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 32, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 32, 4, true>>(s, a);
}
else if(a.N <= 256)
{
return a.N == 256 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 4, true>>(s, a);
}
else if(a.N <= 512)
{
return a.N == 512 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 4, true>>(s, a);
}
else if(a.N <= 1024)
{
return a.N == 1024
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 2048)
{
return a.N == 2048
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, true>>(s, a);
}
else
{
return a.N % 2048 == 0
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, false, true>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, true, true>>(s, a);
}
}
else if(a.N % 2 == 0)
{
if(a.N <= 128)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 2, true>>(s, a);
}
else if(a.N <= 256)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 2, true>>(s, a);
}
else if(a.N <= 512)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 1024)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 2, true>>(s, a);
}
else if(a.N <= 2048)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 64, 2, true>>(s, a);
}
else
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 64, 2, true, true>>(s, a);
}
}
else
{
return a.N <= 2048
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 64, 1, true, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 64, 1, true, true>>(s, a);
}
}
if(r < 0) if(r < 0)
throw std::runtime_error("Without supported instances!"); throw std::runtime_error("Without supported instances!");
......
./bin/tile_example_layernorm2d_fwd -m=700 -n=80 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=128 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=144 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=168 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=184 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=256 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=288 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=344 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=376 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=448 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=512 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=924 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp32 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 ./bin/tile_example_layernorm2d_fwd -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment