Unverified commit cb138394, authored by rocking, committed by GitHub

layernorm2d forward (#1339)



* Add layernorm2d forward

* Refine file path

* clang format

* Exclude ck_tile op from all

* use add_executable instead

* refactor layernorm2d_fwd example

---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
parent 05b10e0e
# not using add_example_executable() to add this target, since we don't want it
# to be included in "make all/install/check"
add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD)
# Layernorm2D forward
This folder contains an example of Layernorm2D forward using the ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with your target, e.g. gfx90a, gfx942, ...
make tile_example_layernorm2d_fwd -j
```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
## example
```
args:
-m m dimension (default:3328)
-n n dimension (default:4096)
-e epsilon (default:1e-5)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
```
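For example, running with the defaults spelled out explicitly (the flags below simply restate the defaults listed above):
```
./build/bin/tile_example_layernorm2d_fwd -m 3328 -n 4096 -e 1e-5 -v 1 -prec fp16
```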
#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include <cstring>
// Host API implementation
float layernorm2d_fwd(layernorm2d_fwd_traits t,
layernorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
#else
using MeanDataType = ck_tile::null_type;
using InvStdDataType = ck_tile::null_type;
#endif
using ComputeDataType = float;
using thread_tile = ck_tile::sequence<4, 4>;
using warp_tile = ck_tile::sequence<8, 128>;
using block_tile = ck_tile::sequence<32, 128>;
using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType,
Shape>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
auto kargs = Kernel::MakeKargs(
a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);
const dim3 grids = Kernel::GridSize(a.M);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock;
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
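// unsupported data type: nothing is launched, report 0 ms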
return 0;
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "m dimension")
.insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
float epsilon = arg_parser.get_float("e");
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
#else
using MeanDataType = ck_tile::null_type;
using InvStdDataType = ck_tile::null_type;
#endif
using ComputeDataType = float;
// host verify
ck_tile::HostTensor<XDataType> x_host({M, N});
ck_tile::HostTensor<GammaDataType> gamma_host({N});
ck_tile::HostTensor<BetaDataType> beta_host({N});
ck_tile::HostTensor<YDataType> y_host_ref({M, N});
ck_tile::HostTensor<YDataType> y_host_dev({M, N});
ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
#ifdef SAVE_MEAN_INV_STD
ck_tile::HostTensor<MeanDataType> mean_host_dev({M});
ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M});
#endif
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-5.f, 5.f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-5.f, 5.f}(beta_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
#ifdef SAVE_MEAN_INV_STD
ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
#endif
x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data());
layernorm2d_fwd_traits traits{data_type};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
mean_buf.GetDeviceBuffer(),
invStd_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
epsilon,
M,
N};
float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true});
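// bytes moved: read x (M*N) + gamma (N) + beta (N), write y (M*N);
// ave_time is in ms, so num_byte / 1e6 / ave_time yields GB/s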
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "[" << data_type << "]"
<< " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s"
<< std::flush;
bool pass = true;
if(do_validation)
{
// reference
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
y_buf.FromDevice(y_host_dev.data());
pass = ck_tile::check_err(y_host_dev, y_host_ref);
#ifdef SAVE_MEAN_INV_STD
mean_buf.FromDevice(mean_host_dev.data());
pass &= ck_tile::check_err(mean_host_dev, mean_host_ref);
invStd_buf.FromDevice(invStd_host_dev.data());
pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref);
#endif
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
std::cout << std::endl << std::flush;
return !pass;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>
struct layernorm2d_fwd_traits
{
std::string data_type;
};
struct layernorm2d_fwd_args
{
const void* p_x;
const void* p_gamma;
const void* p_beta;
void* p_y;
void* p_mean;
void* p_invStd;
float epsilon;
ck_tile::index_t M;
ck_tile::index_t N;
};
// host API
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
......@@ -3,3 +3,4 @@ include_directories(AFTER
)
add_subdirectory(01_fmha)
add_subdirectory(02_layernorm2d)
......@@ -27,6 +27,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
namespace ck_tile {
struct null_type
{
};
} // namespace ck_tile
......@@ -18,6 +18,7 @@
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"
......
......@@ -56,8 +56,9 @@ check_err(const Range& out,
}
const auto is_infinity_error = [=](auto o, auto r) {
-const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
+const bool both_infinite_and_same =
+    std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
};
......@@ -114,8 +115,9 @@ check_err(const Range& out,
}
const auto is_infinity_error = [=](auto o, auto r) {
-const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
+const bool both_infinite_and_same =
+    std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
};
......@@ -173,8 +175,9 @@ check_err(const Range& out,
}
const auto is_infinity_error = [=](auto o, auto r) {
-const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
+const bool both_infinite_and_same =
+    std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
};
......@@ -285,8 +288,9 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
const auto is_infinity_error = [=](auto o, auto r) {
-const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
+const bool both_infinite_and_same =
+    std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
};
......@@ -357,8 +361,9 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
const auto is_infinity_error = [=](auto o, auto r) {
-const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
+const bool both_infinite_and_same =
+    std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename YDataType,
typename MeanDataType,
typename InvStdDataType>
void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n,
const HostTensor<BetaDataType>& beta_n,
HostTensor<YDataType>& y_m_n,
HostTensor<MeanDataType>& mean_m,
HostTensor<InvStdDataType>& invStd_m,
ComputeDataType epsilon)
{
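// Per row m: compute mean/variance over the N dimension with a single-pass (Welford)
// loop, then normalize:
//   y[m, n] = (x[m, n] - mean[m]) / sqrt(var[m] + epsilon) * gamma[n] + beta[n]
// mean[m] and 1 / sqrt(var[m] + epsilon) are also stored when MeanDataType /
// InvStdDataType are not null_type.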
auto layernorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
int count = 0;
ComputeDataType mean = 0;
ComputeDataType variance = 0;
ComputeDataType divisor = 0;
for(int n = 0; n < N; ++n)
{
++count;
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType delta = x - mean;
mean += delta / count;
ComputeDataType delta2 = x - mean;
variance += delta * delta2;
}
// actual variance
variance = variance / count;
divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(variance + epsilon);
if constexpr(!std::is_same_v<MeanDataType, ck_tile::null_type>)
mean_m(m) = ck_tile::type_convert<MeanDataType>(mean);
if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
ComputeDataType beta = ck_tile::type_convert<ComputeDataType>(beta_n(n));
auto y = (x - mean) * divisor;
y = y * gamma + beta;
y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
}
};
make_ParallelTensorFunctor(layernorm2d_fwd_func,
mean_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
namespace ck_tile {
// TODO: Extract some type to wrapper class
template <typename Problem_>
struct Layernorm2dFwd
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = !std::is_same_v<MeanDataType, ck_tile::null_type>;
static constexpr bool kSaveInvStd = !std::is_same_v<InvStdDataType, ck_tile::null_type>;
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
struct Kargs
{
const void* p_x;
const void* p_gamma;
const void* p_beta;
void* p_y;
void* p_mean;
void* p_invStd;
float epsilon;
ck_tile::index_t M;
ck_tile::index_t N;
};
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* p_x,
const void* p_gamma,
const void* p_beta,
void* p_y,
void* p_mean,
void* p_invStd,
float epsilon,
ck_tile::index_t M,
ck_tile::index_t N)
{
return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N};
}
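// one workgroup per kMPerBlock rows (M is assumed to be a multiple of kMPerBlock)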
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; }
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
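// distribute the kMPerBlock x kNPerBlock tile of X over warps, lanes and
// per-thread elements according to BlockShape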
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp>,
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1>,
sequence<2>>{});
}
template <typename Dstr>
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
{
constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
using Lengths = decltype(nDstrSpan.impl_);
ck_tile::index_t ret = 1;
ck_tile::static_for<0, Lengths::size(), 1>{}(
[&](auto idx) { ret *= Lengths::template at(idx); });
return ret;
}
template <typename DistributedTensor>
CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor,
const ComputeDataType epsilon)
{
// TODO: Investigate fast inverse square root algorithm with epsilon
constexpr auto spans = DistributedTensor::get_distributed_spans();
DistributedTensor out_dstr_tensor;
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
out_dstr_tensor(i_idx) = type_convert<ComputeDataType>(1.0f) /
ck_tile::sqrt(in_dstr_tensor[i_idx] + epsilon);
});
return out_dstr_tensor;
}
template <bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y,
MeanDataType* p_mean,
InvStdDataType* p_invStd,
const ComputeDataType epsilon,
ck_tile::index_t M,
ck_tile::index_t N) const
{
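// Two passes over the row's N tiles:
//   pass 1: sweep left-to-right, accumulating per-thread Welford mean/variance,
//           then merge the partial results across lanes within the warp;
//   pass 2: sweep right-to-left (reusing data still resident in cache) and store
//           y = (x - mean) * inv_std * gamma + beta.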
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
const auto gamma_n = make_naive_tensor_view<address_space_enum::global>(
p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
const auto beta_n = make_naive_tensor_view<address_space_enum::global>(
p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
const auto iM = get_block_id() * kMPerBlock;
constexpr auto xDstr = MakeXBlockTileDistribution();
auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock);
// TODO: padding - handle max_count if N % kNPerBlock != 0
constexpr auto NPerThread = GetNPerThread(xDstr);
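// each thread sees NPerThread columns per tile and N / kNPerBlock tiles per row,
// so the per-thread Welford count is capped at NPerThread * N / kNPerBlock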
ThreadWelford<ComputeDataType, XDataType> thread_welford{
type_convert<int>(NPerThread * N / kNPerBlock)};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
auto var_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
move_tile_window(x_block_window, {0, kNPerBlock});
}
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
{
const auto mean_m = make_naive_tensor_view_packed<address_space_enum::global>(
p_mean, make_tuple(M), number<32>{});
auto mean_block_window =
make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
}
if constexpr(kSaveInvStd)
{
const auto inv_std_m = make_naive_tensor_view_packed<address_space_enum::global>(
p_invStd, make_tuple(M), number<32>{});
auto inv_std_block_window =
make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
store_tile(inv_std_block_window, cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
}
// TODO: Extract normalize pipeline
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
auto y_block_window = make_tile_window(
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
constexpr auto betaDstr = gammaDstr;
auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
auto beta_block_window = make_tile_window(
beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
// reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window = N - kNPerBlock;
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {stride_to_right_most_window});
move_tile_window(beta_block_window, {stride_to_right_most_window});
move_tile_window(y_block_window, {0, stride_to_right_most_window});
// Normalization
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x_block_tensor = load_tile(x_block_window);
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
auto y_block_tensor =
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
sweep_tile_span(x_spans[I1], [&](auto idx1) {
constexpr auto j_idx = make_tuple(idx1);
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto mean = mean_compute_block_tensor[i_idx];
const auto inv_std = inv_std_compute_block_tensor[i_idx];
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
});
});
store_tile(y_block_window, y_block_tensor);
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {-kNPerBlock});
move_tile_window(beta_block_window, {-kNPerBlock});
move_tile_window(y_block_window, {0, -kNPerBlock});
}
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
TwoPassLayernorm2dFwd(static_cast<const XDataType*>(kargs.p_x),
static_cast<const GammaDataType*>(kargs.p_gamma),
static_cast<const BetaDataType*>(kargs.p_beta),
static_cast<YDataType*>(kargs.p_y),
static_cast<MeanDataType*>(kargs.p_mean),
static_cast<InvStdDataType*>(kargs.p_invStd),
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.M,
kargs.N);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename XDataType_,
typename GammaDataType_,
typename BetaDataType_,
typename ComputeDataType_,
typename YDataType_,
typename MeanDataType_,
typename InvStdDataType_,
typename BlockShape_>
struct BlockLayernorm2dFwdProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ThreadTile, // Sequence<...
typename WarpTile, // Sequence<...
typename BlockTile> // Sequence<...
struct TileLayernorm2dShape
{
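// Example (the fp16 instantiation in the layernorm2d_fwd example):
//   thread_tile = <4, 4>, warp_tile = <8, 128>, block_tile = <32, 128>
//   => kMThreadPerWarp = 2, kNThreadPerWarp = 32 (64 lanes per warp),
//      kMWarpPerBlock = 4, kNWarpPerBlock = 1, kBlockSize = 4 * warpSize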
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;
// TODO - kNWarpPerBlock can only be 1 until cross-warp Welford is supported
static_assert(kNWarpPerBlock == 1);
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kNWarpPerBlock;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ComputeDataType_, typename XDataType_>
struct ThreadWelford
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
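// Welford's single-pass update (cur_count_ has already been advanced by the caller):
//   mean_k = mean_{k-1} + (x - mean_{k-1}) / k
//   M2_k   = M2_{k-1} + (x - mean_{k-1}) * (x - mean_k)
// 'var' accumulates M2 (the sum of squared deviations); the division by the final
// count happens later, e.g. in WarpMergeWelford when GetActualVariance is true.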
template <typename T>
CK_TILE_DEVICE void Update(T& mean, T& var, T x)
{
if(ck_tile::isnan(x))
{
mean = x;
var = x;
}
else
{
T delta = x - mean;
mean += delta / cur_count_;
T delta2 = x - mean;
var += delta * delta2;
}
}
// [CAUTION] - max_count_ exists to deal with padding
// max_count_ depends on the caller; e.g. a naive and a split-N Welford will compute
// max_count_ differently
CK_TILE_DEVICE constexpr ThreadWelford(int max_count) : cur_count_(0), max_count_(max_count) {}
template <typename XDistributedTensor_,
typename MeanDistributedTensor_,
typename VarDistributedTensor_>
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
MeanDistributedTensor_& mean_tensor,
VarDistributedTensor_& var_tensor)
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
if(cur_count_ < max_count_)
{
++cur_count_;
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
Update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x);
});
}
});
}
template <typename XDistributedTensor_>
CK_TILE_DEVICE static auto MakeInitialMeanVarDistributedTensor()
{
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
constexpr auto reduce_dims = sequence<1>{};
constexpr auto dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
XDistributedTensor_::get_tile_distribution()
.get_static_tile_distribution_encoding(),
reduce_dims));
auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
clear_tile(tensor);
return tensor;
}
template <typename XDistributedTensor_>
CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor)
{
auto mean_tensor = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
auto var_tensor = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
(*this)(x_tensor, mean_tensor, var_tensor);
return ck_tile::make_tuple(mean_tensor, var_tensor);
}
int cur_count_;
int max_count_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ComputeDataType_, bool BroadcastLane = true, bool GetActualVariance = true>
struct WarpMergeWelford
{
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
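// Merge two partial Welford results (a <- merge(a, b)), Chan et al. style:
//   count = count_a + count_b
//   mean  = mean_a + (mean_b - mean_a) * count_b / count
//   M2    = M2_a + M2_b + (mean_b - mean_a)^2 * count_a * count_b / count
// var_a / var_b hold running M2 sums; operator() divides by the merged count at the
// end when GetActualVariance is true.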
template <typename T>
CK_TILE_DEVICE static void
Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
{
int count = count_a + count_b;
T count_ = type_convert<T>(count);
T count_a_ = type_convert<T>(count_a);
T count_b_ = type_convert<T>(count_b);
T count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
T delta = mean_b - mean_a;
mean_a += delta * count_b_over_count;
var_a += var_b + delta * delta * count_a_ * count_b_over_count;
count_a = count;
}
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void
operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
{
using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
"wrong!");
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
const auto rs_idx =
mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
const int original_count = count;
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local_mean = mean_tensor.get_thread_buffer()[i];
auto v_local_var = var_tensor.get_thread_buffer()[i];
auto v_local_count = original_count;
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
constexpr index_t lid_delta =
lid_over_rid_derivative * (1 << (nstage - istage - 1));
// pull data from remote lane
const auto v_remote_mean = warp_shuffle_down(v_local_mean, lid_delta);
const auto v_remote_var = warp_shuffle_down(v_local_var, lid_delta);
const auto v_remote_count = warp_shuffle_down(v_local_count, lid_delta);
// welford merge
Merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count);
});
}
});
// cross-lane broadcast for replication
// only broadcast on R dimension correspond to lane
// (lane id maps to this R dimension)
if constexpr(BroadcastLane)
{
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
const index_t r_id = rs_idx[idim_r];
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// broadcast sweep backward
static_for<0, nstage, 1>{}([&](auto istage) {
// do I hold reduced data?
const bool do_i_hold_reduced_data = r_id < (1 << istage);
constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage);
// pull data from remote lane
const auto v_remote_mean = warp_shuffle_up(v_local_mean, lid_delta);
const auto v_remote_var = warp_shuffle_up(v_local_var, lid_delta);
const auto v_remote_count = warp_shuffle_up(v_local_count, lid_delta);
// decide whether to update local data with remote data
v_local_mean = do_i_hold_reduced_data ? v_local_mean : v_remote_mean;
v_local_var = do_i_hold_reduced_data ? v_local_var : v_remote_var;
v_local_count = do_i_hold_reduced_data ? v_local_count : v_remote_count;
});
}
});
}
mean_tensor.get_thread_buffer()(i) = v_local_mean;
if constexpr(GetActualVariance)
var_tensor.get_thread_buffer()(i) = v_local_var / v_local_count;
else
var_tensor.get_thread_buffer()(i) = v_local_var;
count = v_local_count;
});
}
};
} // namespace ck_tile