Commit 63214d01 authored by letaoqin's avatar letaoqin
Browse files

port layernorm

parent 29d384d0
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD) file(GLOB INSTANCE_SRCS instances/*.cpp)
\ No newline at end of file add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL example_layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE layernorm2d_fwd_fp16.cpp layernorm2d_fwd_fp32.cpp ${INSTANCE_SRCS})
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
# list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
...@@ -6,8 +6,7 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm ...@@ -6,8 +6,7 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
``` ```
# in the root of ck_tile # in the root of ck_tile
mkdir build && cd build mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942), or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_layernorm2d_fwd -j make tile_example_layernorm2d_fwd -j
``` ```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd` This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
...@@ -20,4 +19,4 @@ args: ...@@ -20,4 +19,4 @@ args:
-e epsilon (default:1e-5) -e epsilon (default:1e-5)
-v cpu validation or not (default:1) -v cpu validation or not (default:1)
-prec precision (default:fp16) -prec precision (default:fp16)
``` ```
\ No newline at end of file
...@@ -2,61 +2,8 @@ ...@@ -2,61 +2,8 @@
#include "layernorm2d_fwd.hpp" #include "layernorm2d_fwd.hpp"
#include <cstring> #include <cstring>
// Host API implementation extern float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream);
float layernorm2d_fwd(layernorm2d_fwd_traits t, extern float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream);
layernorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
#else
using MeanDataType = ck_tile::null_type;
using InvStdDataType = ck_tile::null_type;
#endif
using ComputeDataType = float;
using thread_tile = ck_tile::sequence<4, 4>;
using warp_tile = ck_tile::sequence<8, 128>;
using block_tile = ck_tile::sequence<32, 128>;
using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType,
Shape,
true,
true>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
auto kargs = Kernel::MakeKargs(
a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);
const dim3 grids = Kernel::GridSize(a.M);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock;
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return 0;
}
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
...@@ -65,37 +12,37 @@ auto create_args(int argc, char* argv[]) ...@@ -65,37 +12,37 @@ auto create_args(int argc, char* argv[])
.insert("n", "4096", "m dimension") .insert("n", "4096", "m dimension")
.insert("e", "1e-5", "epsilon") .insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision"); .insert("prec", "fp32", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
} }
int main(int argc, char* argv[]) template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{ {
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
float epsilon = arg_parser.get_float("e"); float epsilon = arg_parser.get_float("e");
ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t N = arg_parser.get_int("n");
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
using TypeConfig = LayerNormTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType;
using XDataType = ck_tile::half_t; using MeanDataType = ck_tile::null_type;
using YDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
#else
using MeanDataType = ck_tile::null_type;
using InvStdDataType = ck_tile::null_type; using InvStdDataType = ck_tile::null_type;
#endif
using ComputeDataType = float; using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({M, N}); ck_tile::HostTensor<XDataType> x_host({M, N});
...@@ -108,25 +55,15 @@ int main(int argc, char* argv[]) ...@@ -108,25 +55,15 @@ int main(int argc, char* argv[])
ck_tile::HostTensor<MeanDataType> mean_host_ref({M}); ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M}); ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
#ifdef SAVE_MEAN_INV_STD ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::HostTensor<MeanDataType> mean_host_dev({M}); ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M}); ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
#endif
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-5.f, 5.f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-5.f, 5.f}(beta_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
#ifdef SAVE_MEAN_INV_STD
ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
#endif
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data()); gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data()); beta_buf.ToDevice(beta_host.data());
...@@ -137,26 +74,30 @@ int main(int argc, char* argv[]) ...@@ -137,26 +74,30 @@ int main(int argc, char* argv[])
gamma_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(), y_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
mean_buf.GetDeviceBuffer(),
invStd_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
epsilon, epsilon,
M, M,
N}; N};
float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true}); float ave_time = .0;
if constexpr(std::is_same<DataType, ck_tile::fp16_t>::value)
{
ave_time =
layernorm2d_fwd_fp16(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
}
else if constexpr(std::is_same<DataType, float>::value)
{
ave_time =
layernorm2d_fwd_fp32(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
}
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
<< " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s" << " m:" << M << ", n:" << N << ", " << ave_time * 1.E6 << " ns, " << gb_per_sec
<< std::flush; << " GB/s" << std::flush;
bool pass = true; bool pass = true;
...@@ -176,18 +117,29 @@ int main(int argc, char* argv[]) ...@@ -176,18 +117,29 @@ int main(int argc, char* argv[])
pass = ck_tile::check_err(y_host_dev, y_host_ref); pass = ck_tile::check_err(y_host_dev, y_host_ref);
#ifdef SAVE_MEAN_INV_STD
mean_buf.FromDevice(mean_host_dev.data());
pass &= ck_tile::check_err(mean_host_dev, mean_host_ref);
invStd_buf.FromDevice(invStd_host_dev.data());
pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref);
#endif
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
} }
std::cout << std::endl << std::flush; std::cout << std::endl << std::flush;
return !pass; return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
if(data_type == "fp32")
{
return run<float>(arg_parser) ? 0 : -2;
}
return -3;
} }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
// template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
// template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
template float run_layernorm<float, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
...@@ -13,14 +13,51 @@ struct layernorm2d_fwd_traits ...@@ -13,14 +13,51 @@ struct layernorm2d_fwd_traits
std::string data_type; std::string data_type;
}; };
template <typename DataType>
struct LayerNormTypeConfig;
template <>
struct LayerNormTypeConfig<ck_tile::half_t>
{
using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
#else
using MeanDataType = ck_tile::null_type;
using InvStdDataType = ck_tile::null_type;
#endif
using ComputeDataType = float;
};
template <>
struct LayerNormTypeConfig<float>
{
using XDataType = float;
using YDataType = float;
using GammaDataType = float;
using BetaDataType = float;
#ifdef SAVE_MEAN_INV_STD
using MeanDataType = float;
using InvStdDataType = float;
#else
using MeanDataType = ck_tile::null_type;
using InvStdDataType = ck_tile::null_type;
#endif
using ComputeDataType = float;
};
struct layernorm2d_fwd_args struct layernorm2d_fwd_args
{ {
const void* p_x; const void* p_x;
const void* p_gamma; const void* p_gamma;
const void* p_beta; const void* p_beta;
void* p_y; void* p_y;
void* p_mean; // void* p_mean;
void* p_invStd; // void* p_invStd;
float epsilon; float epsilon;
ck_tile::index_t M; ck_tile::index_t M;
ck_tile::index_t N; ck_tile::index_t N;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
// Selects and launches an explicitly-instantiated fp16 layernorm 2D forward
// kernel based on the row length N.
//
// Dispatch policy (mirrors the instantiation lists in instances/*.cpp):
//   * prefer the widest vector access that divides N (4, then 2);
//   * within a vector width, pick the smallest configuration whose tile
//     covers N;
//   * use the `false` (exact-fit) instance only when N equals the tile size,
//     otherwise the `true` instance -- presumably padded access; confirm
//     against layernorm_dispatch.hpp;
//   * for N beyond the largest single tile, dispatch to the variant with the
//     extra trailing `true` template flag (large-N path), again choosing
//     exact-fit vs. not by N % 2048.
//
// @param param  kernel arguments (tensor pointers, epsilon, M, N)
// @param stream stream/timing configuration forwarded to the kernel launch
// @return average kernel time in milliseconds, as reported by run_layernorm
// @throws std::runtime_error when N is odd (no instantiation covers it)
float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
    // Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
#if 0
    if(param.N % 8 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(param, stream);
        }
        else
        {
            return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(param, stream);
        }
    }
    else if(param.N % 4 == 0)
#endif
    // Vector-4 path.
    if(param.N % 4 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(param, stream);
        }
        else if(param.N <= 2048)
        {
            return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(param, stream);
        }
        else
        {
            // N larger than any single tile: large-N variants.
            return param.N % 2048 == 0
                       ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(param, stream)
                       : run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(param, stream);
        }
    }
    // Vector-2 path.
    else if(param.N % 2 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(param, stream);
        }
        else if(param.N <= 2048)
        {
            return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(param, stream);
        }
        else
        {
            // N larger than any single tile: large-N variants.
            return param.N % 2048 == 0
                       ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(param, stream)
                       : run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(param, stream);
        }
    }
    else
    {
        throw std::runtime_error("Sequence length sizes not supported!");
    }
} // note: removed the stray ';' that followed this brace (empty declaration)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
// extern template declarations for every fp32 run_layernorm instance that
// layernorm2d_fwd_fp32 may dispatch to; definitions live in the fp32
// instance file. Declarations for the four large-N variants (trailing
// `, true`) are included so they are not implicitly re-instantiated in this
// translation unit -- matching both the fp32 instance file and the fp16
// extern list, which already declare their equivalents.
extern template float run_layernorm<float, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
// fp32 entry point: selects a run_layernorm instantiation from the row length N.
// The 4-wide vector access is preferred when N % 4 == 0, 2-wide access covers the
// remaining even N, and odd N is rejected. Within each family the tile grows with
// N, the un-padded kernel is chosen only when N exactly fills the tile, and rows
// longer than 2048 fall back to the two-pass pipeline (padded unless N is a whole
// number of 2048-element tiles).
float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
    const auto n = param.N;

    if(n % 4 == 0)
    {
        if(n <= 128)
        {
            if(n == 128)
                return run_layernorm<float, 1, 32, 4, false>(param, stream);
            return run_layernorm<float, 1, 32, 4, true>(param, stream);
        }
        if(n <= 256)
        {
            if(n == 256)
                return run_layernorm<float, 1, 64, 4, false>(param, stream);
            return run_layernorm<float, 1, 64, 4, true>(param, stream);
        }
        if(n <= 512)
        {
            if(n == 512)
                return run_layernorm<float, 2, 64, 4, false>(param, stream);
            return run_layernorm<float, 2, 64, 4, true>(param, stream);
        }
        if(n <= 1024)
        {
            if(n == 1024)
                return run_layernorm<float, 4, 64, 4, false>(param, stream);
            return run_layernorm<float, 4, 64, 4, true>(param, stream);
        }
        if(n <= 2048)
        {
            if(n == 2048)
                return run_layernorm<float, 8, 64, 4, false>(param, stream);
            return run_layernorm<float, 8, 64, 4, true>(param, stream);
        }
        // N > 2048: two-pass kernel.
        if(n % 2048 == 0)
            return run_layernorm<float, 8, 64, 4, false, true>(param, stream);
        return run_layernorm<float, 8, 64, 4, true, true>(param, stream);
    }

    if(n % 2 == 0)
    {
        if(n <= 128)
        {
            if(n == 128)
                return run_layernorm<float, 1, 64, 2, false>(param, stream);
            return run_layernorm<float, 1, 64, 2, true>(param, stream);
        }
        if(n <= 256)
        {
            if(n == 256)
                return run_layernorm<float, 2, 64, 2, false>(param, stream);
            return run_layernorm<float, 2, 64, 2, true>(param, stream);
        }
        if(n <= 512)
        {
            if(n == 512)
                return run_layernorm<float, 4, 64, 2, false>(param, stream);
            return run_layernorm<float, 4, 64, 2, true>(param, stream);
        }
        if(n <= 1024)
        {
            if(n == 1024)
                return run_layernorm<float, 8, 64, 2, false>(param, stream);
            return run_layernorm<float, 8, 64, 2, true>(param, stream);
        }
        if(n <= 2048)
        {
            if(n == 2048)
                return run_layernorm<float, 16, 64, 2, false>(param, stream);
            return run_layernorm<float, 16, 64, 2, true>(param, stream);
        }
        // N > 2048: two-pass kernel.
        if(n % 2048 == 0)
            return run_layernorm<float, 16, 64, 2, false, true>(param, stream);
        return run_layernorm<float, 16, 64, 2, true, true>(param, stream);
    }

    throw std::runtime_error("Sequence length sizes not supported!");
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include "layernorm2d_fwd.hpp"
// Compile-time dispatch record: derives the tile shape from the template
// parameters, instantiates the matching ck_tile::Layernorm2dFwd kernel, and
// launches it over an M x N row-major problem.
//
// Template parameters:
//   InOutDataType    - element type tag used to look up the kernel's data types
//                      via LayerNormTypeConfig (e.g. float / fp16)
//   NRepeat          - per-thread repeat count along the N dimension
//   NThread          - threads along N (must be <= 64: intra-wave reduction only)
//   VectorAccessSize - vector width of each access along N
//   kPadN            - pad the N dimension (N not an exact multiple of the tile)
//   kTwoPass         - select the two-pass pipeline (N larger than one block tile)
template <typename InOutDataType,
          ck_tile::index_t NRepeat,
          ck_tile::index_t NThread,
          ck_tile::index_t VectorAccessSize,
          bool kPadN,
          bool kTwoPass>
struct layernorm_dispatch
{
    // One row tile per thread along M.
    static constexpr ck_tile::index_t MRepeat = 1;

    static_assert(NThread <= 64, "We only support intra-wave reduction");
    // NOTE(review): WaveNum = NThread / 16 — presumably N-threads are grouped
    // 16-wide per wave slice; confirm against TileLayernorm2dShape's layout.
    static constexpr ck_tile::index_t WaveNum = NThread / 16;

    // clang-format off
    // thread_tile: (M, NRepeat, vector) footprint of one thread
    // warp_tile  : (M, N) footprint of one warp (64 lanes)
    // block_tile : (M, N) footprint of the whole workgroup
    using thread_tile = ck_tile::sequence<MRepeat, NRepeat, VectorAccessSize>;
    using warp_tile = ck_tile::sequence<MRepeat*64/NThread, NRepeat * NThread*VectorAccessSize>;
    using block_tile = ck_tile::sequence<MRepeat*WaveNum*64/NThread, NRepeat * NThread*VectorAccessSize>;
    // clang-format on

    using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;

    // Bundle data types (from the type-config lookup), shape and pipeline flags
    // into the problem descriptor consumed by the kernel.
    using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<
        typename LayerNormTypeConfig<InOutDataType>::XDataType,
        typename LayerNormTypeConfig<InOutDataType>::GammaDataType,
        typename LayerNormTypeConfig<InOutDataType>::BetaDataType,
        typename LayerNormTypeConfig<InOutDataType>::ComputeDataType,
        typename LayerNormTypeConfig<InOutDataType>::YDataType,
        typename LayerNormTypeConfig<InOutDataType>::MeanDataType,
        typename LayerNormTypeConfig<InOutDataType>::InvStdDataType,
        Shape,
        kPadN,
        kTwoPass>;

    using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;

    // Launch the kernel for the given problem sizes/pointers and return the
    // value reported by ck_tile::launch_kernel for the supplied stream config.
    static float Run(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
    {
        using k_ = Kernel;

        const dim3 grids = k_::GridSize(param.M);
        constexpr dim3 blocks = k_::BlockSize();

        constexpr ck_tile::index_t kBlockPerCu = 1;

        return ck_tile::launch_kernel(stream,
                                      ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{},
                                                                                  grids,
                                                                                  blocks,
                                                                                  0, // no dynamic LDS
                                                                                  param.p_x,
                                                                                  param.p_gamma,
                                                                                  param.p_beta,
                                                                                  param.p_y,
                                                                                  param.epsilon,
                                                                                  param.M,
                                                                                  param.N));
    };
};
template <typename InOutDataType,
ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kPadN,
bool kTwoPass = false>
float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
return layernorm_dispatch<InOutDataType, NRepeat, NThread, VectorAccessSize, kPadN, kTwoPass>::
Run(param, stream);
};
# Smoke-test sweep for the layernorm2d forward example: both precisions over a
# fixed set of row lengths N (M fixed at 700, 1000 timing repeats each run).
# Command order matches the original list: all fp32 runs first, then all fp16.
for prec in fp32 fp16; do
    for n in 80 128 144 168 184 256 288 344 376 448 512 924 1024 1078 1996 4080; do
        ./bin/tile_example_layernorm2d_fwd -m=700 -n=$n -e=1e-12 -v=1 -prec=$prec -repeat=1000
    done
done
\ No newline at end of file
...@@ -31,14 +31,10 @@ struct Layernorm2dFwd ...@@ -31,14 +31,10 @@ struct Layernorm2dFwd
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
struct Kargs struct Kargs
{ {
...@@ -47,8 +43,8 @@ struct Layernorm2dFwd ...@@ -47,8 +43,8 @@ struct Layernorm2dFwd
const void* p_beta; const void* p_beta;
void* p_y; void* p_y;
void* p_mean; // void* p_mean;
void* p_invStd; // void* p_invStd;
float epsilon; float epsilon;
...@@ -69,7 +65,10 @@ struct Layernorm2dFwd ...@@ -69,7 +65,10 @@ struct Layernorm2dFwd
return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N}; return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N};
} }
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M)
{
return (M + kMPerBlock - 1) / kMPerBlock;
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; } CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
...@@ -81,11 +80,11 @@ struct Layernorm2dFwd ...@@ -81,11 +80,11 @@ struct Layernorm2dFwd
tile_distribution_encoding< tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>, tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>, sequence<S::kNRepeat, S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>, tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>, tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2>, sequence<1, 2, 2>,
sequence<2, 2>>{}); sequence<2, 0, 3>>{});
} }
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution() CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
...@@ -95,32 +94,26 @@ struct Layernorm2dFwd ...@@ -95,32 +94,26 @@ struct Layernorm2dFwd
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp>, sequence<S::kMWarpPerBlock, S::kMThreadPerWarp>,
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>, tuple<sequence<S::kNRepeat, S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<0, 1>, sequence<0, 1>>, tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>, tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1>, sequence<1, 1>,
sequence<2>>{}); sequence<0, 3>>{});
} }
CK_TILE_DEVICE static int GetWelfordMaxCount(int N) template <typename Dstr>
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
{ {
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread; constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
int thread_id_n = get_thread_id() % kNThreadPerBlock; using Lengths = decltype(nDstrSpan.impl_);
int max_count =
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
int n_per_block_tail_loop =
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
if(n_per_block_tail_loop > 0) ck_tile::index_t ret = 1;
{
int thread_max_n = (thread_id_n + 1) * kNPerThread; ck_tile::static_for<0, Lengths::size(), 1>{}(
int delta = thread_max_n - n_per_block_tail_loop; [&](auto idx) { ret *= Lengths::template at(idx); });
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
}
return max_count; return ret;
} }
template <typename DistributedTensor> template <typename DistributedTensor>
...@@ -141,127 +134,70 @@ struct Layernorm2dFwd ...@@ -141,127 +134,70 @@ struct Layernorm2dFwd
return out_dstr_tensor; return out_dstr_tensor;
} }
template <typename XBlockWindow, CK_TILE_HOST_DEVICE static constexpr auto
typename GammaBlockWindow, GetLastloopLayerNormIntraLaneReduceCount(index_t NLength)
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{ {
// TODO - Optimize tail loop to reduce move_tile_window() using S = typename Problem::BlockShape;
index_t num_n_tile_iteration = // S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock)); auto LastloopN = NLength % kNPerBlock == 0 ? kNPerBlock : NLength % kNPerBlock;
constexpr auto NThread = S::kNWarpPerBlock * S::kNThreadPerWarp;
int welford_max_count = GetWelfordMaxCount(N); auto iNLane = get_thread_local_1d_id() % NThread;
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count}; auto iN0 = LastloopN / (S::kNPerThread * S::kNThreadPerWarp);
auto iN1 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) / S::kNPerThread;
using XTensorType = decltype(load_tile(x_block_window)); auto N2 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) % S::kNPerThread;
auto mean_compute_block_tensor = auto iN3 = iNLane < iN1 ? S::kNPerThread : iNLane == iN1 ? N2 : 0;
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
auto var_compute_block_tensor = return iN0 * S::kNPerThread + iN3;
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>(); }
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
move_tile_window(x_block_window, {0, kNPerBlock});
}
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
if constexpr(kSaveInvStd)
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window =
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
move_tile_window(x_block_window, {0, -kNPerBlock}); template <bool Cond = (kHasGamma && kHasBeta)>
move_tile_window(gamma_block_window, {stride_to_right_most_window}); CK_TILE_DEVICE std::enable_if_t<Cond> OnePassLayernorm2dFwd(const XDataType* p_x,
move_tile_window(beta_block_window, {stride_to_right_most_window}); const GammaDataType* p_gamma,
move_tile_window(y_block_window, {0, stride_to_right_most_window}); const BetaDataType* p_beta,
YDataType* p_y,
const ComputeDataType epsilon,
ck_tile::index_t M,
ck_tile::index_t N) const
{
using S = typename Problem::BlockShape;
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
// Normalization const auto x_m_n = [&]() {
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
{ p_x, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
const auto x_block_tensor = load_tile(x_block_window);
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<false, kPadN>{});
}();
auto y_block_tensor = const auto gamma_n = [&]() {
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution()); const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
p_gamma, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
sweep_tile_span(x_spans[I1], [&](auto idx1) { return pad_tensor_view(
constexpr auto j_idx = make_tuple(idx1); gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]); }();
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) { const auto beta_n = [&]() {
constexpr auto i_idx = make_tuple(idx0); const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
constexpr auto i_j_idx = make_tuple(idx0, idx1); p_beta, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
const auto mean = mean_compute_block_tensor[i_idx]; return pad_tensor_view(
const auto inv_std = inv_std_compute_block_tensor[i_idx]; gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]); const auto iM = get_block_id() * kMPerBlock;
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y); constexpr auto xDstr = MakeXBlockTileDistribution();
});
});
store_tile(y_block_window, y_block_tensor); auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
move_tile_window(x_block_window, {0, -kNPerBlock}); auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
move_tile_window(gamma_block_window, {-kNPerBlock});
move_tile_window(beta_block_window, {-kNPerBlock});
move_tile_window(y_block_window, {0, -kNPerBlock});
}
}
template <typename XBlockWindow, ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count_last};
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window)); using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor = auto mean_compute_block_tensor =
...@@ -274,21 +210,37 @@ struct Layernorm2dFwd ...@@ -274,21 +210,37 @@ struct Layernorm2dFwd
const auto x_block_tensor = load_tile(x_block_window); const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
constexpr auto betaDstr = gammaDstr;
auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
auto beta_block_window =
make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
// TODO: support cross warp Welford // TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}( WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean) // TODO: Extract normalize pipeline
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor)); const auto y_m_n = [&]() {
if constexpr(kSaveInvStd) const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
store_tile(inv_std_block_window, p_y, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// normalize return pad_tensor_view(y_dram_naive,
const auto gamma_block_tensor = load_tile(gamma_block_window); make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
const auto beta_block_tensor = load_tile(beta_block_window); sequence<false, kPadN>{});
}();
auto y_block_window = make_tile_window(
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
...@@ -317,43 +269,42 @@ struct Layernorm2dFwd ...@@ -317,43 +269,42 @@ struct Layernorm2dFwd
store_tile(y_block_window, y_block_tensor); store_tile(y_block_window, y_block_tensor);
} }
CK_TILE_DEVICE void operator()(Kargs kargs) const template <bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y,
const ComputeDataType epsilon,
ck_tile::index_t M,
ck_tile::index_t N) const
{ {
using S = typename Problem::BlockShape;
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
const auto x_m_n = [&]() { const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x), p_x, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(x_dram_naive, return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{}); sequence<false, true>{});
}(); }();
const auto gamma_n = [&]() { const auto gamma_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma), p_gamma, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view( return pad_tensor_view(
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{}); gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<true>{});
}(); }();
const auto beta_n = [&]() { const auto beta_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BetaDataType*>(kargs.p_beta), p_beta, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view( return pad_tensor_view(
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{}); gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<true>{});
}(); }();
const auto iM = get_block_id() * kMPerBlock; const auto iM = get_block_id() * kMPerBlock;
...@@ -363,17 +314,52 @@ struct Layernorm2dFwd ...@@ -363,17 +314,52 @@ struct Layernorm2dFwd
auto x_block_window = make_tile_window( auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr); x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane((N + kNPerBlock - 1) / kNPerBlock);
auto intra_thread_count = S::kNRepeat * S::kNPerThread * (num_n_tile_iteration - 1);
auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count};
ThreadWelford<ComputeDataType, XDataType> thread_welford_last{intra_thread_count_last};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
auto var_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration - 1; ++iN)
{
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
move_tile_window(x_block_window, {0, kNPerBlock});
}
const auto x_block_tensor_ = load_tile(x_block_window);
thread_welford_last.cur_count_ += intra_thread_count;
thread_welford_last.max_count_ += intra_thread_count;
thread_welford_last(x_block_tensor_, mean_compute_block_tensor, var_compute_block_tensor);
thread_welford.cur_count_ += intra_thread_count_last;
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
// TODO: Extract normalize pipeline
const auto y_m_n = [&]() { const auto y_m_n = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y), p_y, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(y_dram_naive, return pad_tensor_view(y_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{}); sequence<false, true>{});
}(); }();
auto y_block_window = make_tile_window( auto y_block_window = make_tile_window(
...@@ -385,67 +371,86 @@ struct Layernorm2dFwd ...@@ -385,67 +371,86 @@ struct Layernorm2dFwd
auto gamma_block_window = auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr); make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
auto beta_block_window = make_tile_window( auto beta_block_window =
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr); make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
auto mean_block_window = [&]() {
if constexpr(kSaveMean)
{
const auto mean_m = [&]() {
const auto mean_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<MeanDataType*>(kargs.p_mean),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
}();
return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
auto inv_std_block_window = [&]() { // reverse read x to reuse cache
if constexpr(kSaveInvStd) ck_tile::index_t stride_to_right_most_window =
{ N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
const auto inv_std_m = [&]() {
const auto inv_std_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<InvStdDataType*>(kargs.p_invStd),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
}();
return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
if(kargs.N <= kNPerBlock) move_tile_window(gamma_block_window, {stride_to_right_most_window});
OnePassLayernorm2dFwd(x_block_window, move_tile_window(beta_block_window, {stride_to_right_most_window});
gamma_block_window, move_tile_window(y_block_window, {0, stride_to_right_most_window});
beta_block_window,
y_block_window, // Normalization
mean_block_window, for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
inv_std_block_window, {
static_cast<const ComputeDataType>(kargs.epsilon), const auto x_block_tensor = load_tile(x_block_window);
kargs.N); const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
auto y_block_tensor =
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
sweep_tile_span(x_spans[I1], [&](auto idx1) {
constexpr auto j_idx = make_tuple(idx1);
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto mean = mean_compute_block_tensor[i_idx];
const auto inv_std = inv_std_compute_block_tensor[i_idx];
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
});
});
store_tile(y_block_window, y_block_tensor);
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {-kNPerBlock});
move_tile_window(beta_block_window, {-kNPerBlock});
move_tile_window(y_block_window, {0, -kNPerBlock});
}
}
CK_TILE_DEVICE void operator()(const void* p_x,
const void* p_gamma,
const void* p_beta,
void* p_y,
const ComputeDataType epsilon,
ck_tile::index_t M,
ck_tile::index_t N) const
{
if constexpr(kTwoPass)
{
TwoPassLayernorm2dFwd(static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y),
static_cast<const ComputeDataType>(epsilon),
M,
N);
}
else else
TwoPassLayernorm2dFwd(x_block_window, {
gamma_block_window,
beta_block_window, OnePassLayernorm2dFwd(static_cast<const XDataType*>(p_x),
y_block_window, static_cast<const GammaDataType*>(p_gamma),
mean_block_window, static_cast<const BetaDataType*>(p_beta),
inv_std_block_window, static_cast<YDataType*>(p_y),
static_cast<const ComputeDataType>(kargs.epsilon), static_cast<const ComputeDataType>(epsilon),
kargs.N); M,
N);
}
} }
}; };
......
...@@ -15,20 +15,20 @@ template <typename XDataType_, ...@@ -15,20 +15,20 @@ template <typename XDataType_,
typename MeanDataType_, typename MeanDataType_,
typename InvStdDataType_, typename InvStdDataType_,
typename BlockShape_, typename BlockShape_,
bool kPadM_, bool kPadN_,
bool kPadN_> bool kTwoPass_>
struct BlockLayernorm2dFwdProblem struct BlockLayernorm2dFwdProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>; using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>; using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>; using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>; using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_;
static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -12,13 +12,14 @@ template <typename ThreadTile, // Sequence<... ...@@ -12,13 +12,14 @@ template <typename ThreadTile, // Sequence<...
struct TileLayernorm2dShape struct TileLayernorm2dShape
{ {
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{}); static constexpr index_t kNRepeat = ThreadTile::at(number<1>{});
static constexpr index_t kNPerThread = ThreadTile::at(number<2>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{}); static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread; static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread / kNRepeat;
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{}); static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment