"doc/git@developer.sourcefind.cn:wqshmzh/ktransformers.git" did not exist on "0905d2e270853d1de11ab93adefe067606c3287c"
Commit 63214d01 authored by letaoqin's avatar letaoqin
Browse files

port layernorm

parent 29d384d0
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD)
\ No newline at end of file
message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL example_layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE layernorm2d_fwd_fp16.cpp layernorm2d_fwd_fp32.cpp ${INSTANCE_SRCS})
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
# list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
This folder contains example for Layernorm2D forward using ck_tile tile-programming
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_layernorm2d_fwd -j
```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
......@@ -20,4 +19,4 @@ args:
-e epsilon (default:1e-5)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
```
......@@ -2,61 +2,8 @@
#include "layernorm2d_fwd.hpp"

#include <cstring>
#include <stdexcept>
// Host API implementation
// Host API entry point: configures and launches the layernorm2d forward
// kernel for the precision requested in the traits.
//
// Parameters:
//   t - traits carrying the precision string (currently only "fp16" is wired)
//   a - kernel arguments: tensor pointers, epsilon, and problem sizes M, N
//   s - stream configuration forwarded to ck_tile::launch_kernel
// Returns the average kernel time measured by ck_tile::launch_kernel.
// Throws std::runtime_error for unsupported precisions (previously this
// silently returned 0, which was indistinguishable from a 0-ms launch).
float layernorm2d_fwd(layernorm2d_fwd_traits t,
                      layernorm2d_fwd_args a,
                      const ck_tile::stream_config& s)
{
    if(t.data_type.compare("fp16") == 0)
    {
        // Half-precision tensor I/O with float accumulation.
        using XDataType     = ck_tile::half_t;
        using YDataType     = ck_tile::half_t;
        using GammaDataType = ck_tile::half_t;
        using BetaDataType  = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
        // Mean/inv-std results are written out only when enabled at build time.
        using MeanDataType   = ck_tile::half_t;
        using InvStdDataType = ck_tile::half_t;
#else
        // null_type disables the statistics outputs.
        using MeanDataType   = ck_tile::null_type;
        using InvStdDataType = ck_tile::null_type;
#endif
        using ComputeDataType = float;

        // Fixed tile configuration -- assumed to be <M, N> extents per tile
        // level; confirm against TileLayernorm2dShape.
        using thread_tile = ck_tile::sequence<4, 4>;
        using warp_tile   = ck_tile::sequence<8, 128>;
        using block_tile  = ck_tile::sequence<32, 128>;

        using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;

        using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<XDataType,
                                                                    GammaDataType,
                                                                    BetaDataType,
                                                                    ComputeDataType,
                                                                    YDataType,
                                                                    MeanDataType,
                                                                    InvStdDataType,
                                                                    Shape,
                                                                    true,
                                                                    true>;

        using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;

        auto kargs = Kernel::MakeKargs(
            a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);

        const dim3 grids      = Kernel::GridSize(a.M);
        constexpr dim3 blocks = Kernel::BlockSize();

        constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock;

        float ave_time = ck_tile::launch_kernel(
            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

        return ave_time;
    }

    // Report unsupported precisions loudly, consistent with the fp16/fp32
    // dispatchers in this example, instead of returning a fake 0 time.
    throw std::runtime_error("data type not supported: " + t.data_type);
}
extern float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream);
auto create_args(int argc, char* argv[])
{
......@@ -65,37 +12,37 @@ auto create_args(int argc, char* argv[])
.insert("n", "4096", "m dimension")
.insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision");
.insert("prec", "fp32", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
int main(int argc, char* argv[])
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
float epsilon = arg_parser.get_float("e");
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
using TypeConfig = LayerNormTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType;
using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
#else
using MeanDataType = ck_tile::null_type;
using MeanDataType = ck_tile::null_type;
using InvStdDataType = ck_tile::null_type;
#endif
using ComputeDataType = float;
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({M, N});
......@@ -108,25 +55,15 @@ int main(int argc, char* argv[])
ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
#ifdef SAVE_MEAN_INV_STD
ck_tile::HostTensor<MeanDataType> mean_host_dev({M});
ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M});
#endif
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-5.f, 5.f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-5.f, 5.f}(beta_host);
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
#ifdef SAVE_MEAN_INV_STD
ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
#endif
x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data());
......@@ -137,26 +74,30 @@ int main(int argc, char* argv[])
gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
mean_buf.GetDeviceBuffer(),
invStd_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
epsilon,
M,
N};
float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true});
float ave_time = .0;
if constexpr(std::is_same<DataType, ck_tile::fp16_t>::value)
{
ave_time =
layernorm2d_fwd_fp16(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
}
else if constexpr(std::is_same<DataType, float>::value)
{
ave_time =
layernorm2d_fwd_fp32(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
}
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "[" << data_type << "]"
<< " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s"
<< std::flush;
<< " m:" << M << ", n:" << N << ", " << ave_time * 1.E6 << " ns, " << gb_per_sec
<< " GB/s" << std::flush;
bool pass = true;
......@@ -176,18 +117,29 @@ int main(int argc, char* argv[])
pass = ck_tile::check_err(y_host_dev, y_host_ref);
#ifdef SAVE_MEAN_INV_STD
mean_buf.FromDevice(mean_host_dev.data());
pass &= ck_tile::check_err(mean_host_dev, mean_host_ref);
invStd_buf.FromDevice(invStd_host_dev.data());
pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref);
#endif
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
std::cout << std::endl << std::flush;
return !pass;
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
if(data_type == "fp32")
{
return run<float>(arg_parser) ? 0 : -2;
}
return -3;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"

// Explicit instantiations of run_layernorm for fp16 WITHOUT N-padding
// (kPadN == false). Template arguments are
// <DataType, NRepeat, NThread, VectorAccessSize, kPadN[, kTwoPass]>.
// The commented-out vector-size-8 variants are disabled -- presumably due to
// the compiler performance issue noted in layernorm2d_fwd_fp16; confirm
// before re-enabling.
// clang-format off
// template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// Two-pass variants (trailing kTwoPass == true) used for N > 2048.
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"

// Explicit instantiations of run_layernorm for fp16 WITH N-padding
// (kPadN == true), used when N does not exactly fill the tile. Template
// arguments are <DataType, NRepeat, NThread, VectorAccessSize, kPadN[, kTwoPass]>.
// Vector-size-8 variants are disabled, mirroring the unpadded instance file.
// clang-format off
// template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// Two-pass variants (trailing kTwoPass == true) used for N > 2048.
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"

// Explicit instantiations of run_layernorm for fp32: both unpadded
// (kPadN == false) and padded (kPadN == true) variants, plus the two-pass
// (trailing true) variants used for N > 2048. Template arguments are
// <DataType, NRepeat, NThread, VectorAccessSize, kPadN[, kTwoPass]>.
// clang-format off
template float run_layernorm<float, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
......@@ -13,14 +13,51 @@ struct layernorm2d_fwd_traits
std::string data_type;
};
// Maps the user-facing in/out element type to the full set of data types
// the layernorm kernel needs (tensor I/O, optional statistics, accumulator).
// Only fp16 and fp32 are specialized; other types fail at compile time.
template <typename DataType>
struct LayerNormTypeConfig;

// fp16 configuration: half-precision tensors, float accumulation.
template <>
struct LayerNormTypeConfig<ck_tile::half_t>
{
    using XDataType     = ck_tile::half_t;
    using YDataType     = ck_tile::half_t;
    using GammaDataType = ck_tile::half_t;
    using BetaDataType  = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
    // Mean/inv-std outputs are materialized only with SAVE_MEAN_INV_STD.
    using MeanDataType   = ck_tile::half_t;
    using InvStdDataType = ck_tile::half_t;
#else
    // null_type disables the statistics outputs.
    using MeanDataType   = ck_tile::null_type;
    using InvStdDataType = ck_tile::null_type;
#endif
    using ComputeDataType = float;
};

// fp32 configuration: everything in float.
template <>
struct LayerNormTypeConfig<float>
{
    using XDataType     = float;
    using YDataType     = float;
    using GammaDataType = float;
    using BetaDataType  = float;
#ifdef SAVE_MEAN_INV_STD
    using MeanDataType   = float;
    using InvStdDataType = float;
#else
    using MeanDataType   = ck_tile::null_type;
    using InvStdDataType = ck_tile::null_type;
#endif
    using ComputeDataType = float;
};
struct layernorm2d_fwd_args
{
const void* p_x;
const void* p_gamma;
const void* p_beta;
void* p_y;
void* p_mean;
void* p_invStd;
// void* p_mean;
// void* p_invStd;
float epsilon;
ck_tile::index_t M;
ck_tile::index_t N;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
// Runtime dispatcher for the fp16 layernorm kernel: selects one of the
// explicitly instantiated run_layernorm configurations from the row length N.
//
// Selection scheme:
//   - Vector width by divisibility: N % 4 == 0 -> 4-wide, N % 2 == 0 -> 2-wide
//     (8-wide path is compiled out, see #if 0 below).
//   - Within a width, the tile grows with N; a row that exactly fills the
//     largest matching tile uses the unpadded variant (kPadN == false),
//     otherwise the padded one.
//   - N > 2048 falls back to the two-pass variants (trailing template
//     argument true), unpadded only when N is a multiple of 2048.
// Throws std::runtime_error when N is odd (no instantiation exists).
float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
// Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
#if 0
    if(param.N % 8 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(param, stream);
        }
        else
        {
            return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(param, stream);
        }
    }
    else if(param.N % 4 == 0)
#endif
    // 4-wide vector access path.
    if(param.N % 4 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(param, stream);
        }
        else if(param.N <= 2048)
        {
            return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(param, stream);
        }
        else
        {
            // Rows longer than one tile: two-pass pipeline.
            return param.N % 2048 == 0
                       ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(param, stream)
                       : run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(param, stream);
        }
    }
    // 2-wide vector access path.
    else if(param.N % 2 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(param, stream);
        }
        else if(param.N <= 2048)
        {
            return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(param, stream);
        }
        else
        {
            // Rows longer than one tile: two-pass pipeline.
            return param.N % 2048 == 0
                       ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(param, stream)
                       : run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(param, stream);
        }
    }
    else
    {
        // Odd N: no instantiated configuration can service this shape.
        throw std::runtime_error("Sequence length sizes not supported!");
    }
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
extern template float run_layernorm<float, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
// Runtime dispatcher for the fp32 layernorm kernel. Same selection scheme as
// the fp16 dispatcher: vector width from N's divisibility (4-wide then
// 2-wide), tile size from N, unpadded variant only when N exactly fills the
// tile, and the two-pass variants for N > 2048.
// Throws std::runtime_error when N is odd (no instantiation exists).
float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
    // 4-wide vector access path.
    if(param.N % 4 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<float, 1, 32, 4, false>(param, stream)
                                  : run_layernorm<float, 1, 32, 4, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<float, 1, 64, 4, false>(param, stream)
                                  : run_layernorm<float, 1, 64, 4, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<float, 2, 64, 4, false>(param, stream)
                                  : run_layernorm<float, 2, 64, 4, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<float, 4, 64, 4, false>(param, stream)
                                   : run_layernorm<float, 4, 64, 4, true>(param, stream);
        }
        else if(param.N <= 2048)
        {
            return param.N == 2048 ? run_layernorm<float, 8, 64, 4, false>(param, stream)
                                   : run_layernorm<float, 8, 64, 4, true>(param, stream);
        }
        else
        {
            // Rows longer than one tile: two-pass pipeline.
            return param.N % 2048 == 0 ? run_layernorm<float, 8, 64, 4, false, true>(param, stream)
                                       : run_layernorm<float, 8, 64, 4, true, true>(param, stream);
        }
    }
    // 2-wide vector access path.
    else if(param.N % 2 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<float, 1, 64, 2, false>(param, stream)
                                  : run_layernorm<float, 1, 64, 2, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<float, 2, 64, 2, false>(param, stream)
                                  : run_layernorm<float, 2, 64, 2, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<float, 4, 64, 2, false>(param, stream)
                                  : run_layernorm<float, 4, 64, 2, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<float, 8, 64, 2, false>(param, stream)
                                   : run_layernorm<float, 8, 64, 2, true>(param, stream);
        }
        else if(param.N <= 2048)
        {
            return param.N == 2048 ? run_layernorm<float, 16, 64, 2, false>(param, stream)
                                   : run_layernorm<float, 16, 64, 2, true>(param, stream);
        }
        else
        {
            // Rows longer than one tile: two-pass pipeline.
            return param.N % 2048 == 0 ? run_layernorm<float, 16, 64, 2, false, true>(param, stream)
                                       : run_layernorm<float, 16, 64, 2, true, true>(param, stream);
        }
    }
    else
    {
        // Odd N: no instantiated configuration can service this shape.
        throw std::runtime_error("Sequence length sizes not supported!");
    }
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include "layernorm2d_fwd.hpp"
// Compile-time kernel configuration and launcher for one layernorm tile shape.
//
// Template parameters:
//   InOutDataType    - element type of x/gamma/beta/y (fp16 or float here)
//   NRepeat          - middle extent of the per-thread tile along N
//   NThread          - threads cooperating along N (must be <= 64: only
//                      intra-wave reduction is supported, see static_assert)
//   VectorAccessSize - innermost per-thread tile extent (vector width)
//   kPadN            - forwarded to BlockLayernorm2dFwdProblem; presumably
//                      pads rows that do not fill the tile -- confirm there
//   kTwoPass         - forwarded to BlockLayernorm2dFwdProblem; selects the
//                      two-pass pipeline used for long rows
template <typename InOutDataType,
          ck_tile::index_t NRepeat,
          ck_tile::index_t NThread,
          ck_tile::index_t VectorAccessSize,
          bool kPadN,
          bool kTwoPass>
struct layernorm_dispatch
{
    // Only one tile repetition along M in this example.
    static constexpr ck_tile::index_t MRepeat = 1;

    static_assert(NThread <= 64, "We only support intra-wave reduction");
    // NOTE(review): WaveNum = NThread / 16 -- verify the intended relation to
    // the 64-lane wave size used in the warp/block tile arithmetic below.
    static constexpr ck_tile::index_t WaveNum = NThread / 16;

    // Per-thread / per-warp / per-block tile extents; the second (N) extent
    // is always NRepeat * NThread * VectorAccessSize elements.
    // clang-format off
    using thread_tile = ck_tile::sequence<MRepeat, NRepeat, VectorAccessSize>;
    using warp_tile   = ck_tile::sequence<MRepeat*64/NThread, NRepeat * NThread*VectorAccessSize>;
    using block_tile  = ck_tile::sequence<MRepeat*WaveNum*64/NThread, NRepeat * NThread*VectorAccessSize>;
    // clang-format on

    using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;

    // Problem description binds the type config for InOutDataType to the
    // tile shape and the kPadN/kTwoPass pipeline switches.
    using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<
        typename LayerNormTypeConfig<InOutDataType>::XDataType,
        typename LayerNormTypeConfig<InOutDataType>::GammaDataType,
        typename LayerNormTypeConfig<InOutDataType>::BetaDataType,
        typename LayerNormTypeConfig<InOutDataType>::ComputeDataType,
        typename LayerNormTypeConfig<InOutDataType>::YDataType,
        typename LayerNormTypeConfig<InOutDataType>::MeanDataType,
        typename LayerNormTypeConfig<InOutDataType>::InvStdDataType,
        Shape,
        kPadN,
        kTwoPass>;

    using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;

    // Launches the kernel and returns whatever ck_tile::launch_kernel reports
    // (average time when timing is enabled in the stream config).
    // Note: mean/inv-std pointers are not passed -- this path does not write
    // the statistics outputs.
    static float Run(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
    {
        using k_ = Kernel;

        const dim3 grids      = k_::GridSize(param.M);
        constexpr dim3 blocks = k_::BlockSize();

        constexpr ck_tile::index_t kBlockPerCu = 1;

        return ck_tile::launch_kernel(stream,
                                      ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{},
                                                                                  grids,
                                                                                  blocks,
                                                                                  0,
                                                                                  param.p_x,
                                                                                  param.p_gamma,
                                                                                  param.p_beta,
                                                                                  param.p_y,
                                                                                  param.epsilon,
                                                                                  param.M,
                                                                                  param.N));
    };
};
template <typename InOutDataType,
ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kPadN,
bool kTwoPass = false>
float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
return layernorm_dispatch<InOutDataType, NRepeat, NThread, VectorAccessSize, kPadN, kTwoPass>::
Run(param, stream);
};
# Benchmark/validation sweep for tile_example_layernorm2d_fwd:
# m=700 rows, assorted row lengths n, both precisions, epsilon 1e-12,
# verification on (-v=1), 1000 timed repeats per case.
# Same invocations, in the same order, as listing each command by hand:
# all fp32 sizes first, then all fp16 sizes.
for prec in fp32 fp16; do
    for n in 80 128 144 168 184 256 288 344 376 448 512 924 1024 1078 1996 4080; do
        ./bin/tile_example_layernorm2d_fwd -m=700 -n=$n -e=1e-12 -v=1 -prec=$prec -repeat=1000
    done
done
\ No newline at end of file
......@@ -15,20 +15,20 @@ template <typename XDataType_,
typename MeanDataType_,
typename InvStdDataType_,
typename BlockShape_,
bool kPadM_,
bool kPadN_>
bool kPadN_,
bool kTwoPass_>
struct BlockLayernorm2dFwdProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTwoPass = kTwoPass_;
};
} // namespace ck_tile
......@@ -12,13 +12,14 @@ template <typename ThreadTile, // Sequence<...
struct TileLayernorm2dShape
{
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
static constexpr index_t kNRepeat = ThreadTile::at(number<1>{});
static constexpr index_t kNPerThread = ThreadTile::at(number<2>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread / kNRepeat;
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment