Unverified Commit 0bd6b842 authored by rocking5566's avatar rocking5566 Committed by GitHub
Browse files

Layernorm welford (#346)



* Add threadwise and blockwise welford

* Rename gridwise op, prepare to add welford version

* implement welford and integrate welford into layernorm

* Take care of tail loop

* Fix buf when ThreadSliceK > 1

* Fix bug of merging of two empty set

* Rename clip to clamp

* 1. Fix type of count
2. Remove useless static_assert

* Do not inherit Reduction::Argument

* [What] replace __syncthreads() with block_sync_lds()
[Why] __syncthreads might wait both lgkmcnt(0) and vmcnt(0)

* Add y stride

* Rename.
DeviceLayernorm -> DeviceLayernormImpl
DeviceNormalization2 -> DeviceLayernorm

* Move literal ""_uz & ""_zu into namespace 'literals'

* Move namespace 'literals' as 'ck::literals'
Co-authored-by: default avatarPo-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent c20a75b0
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -29,7 +29,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -29,7 +29,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType, using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, AccDataType,
...@@ -90,6 +90,7 @@ int main() ...@@ -90,6 +90,7 @@ int main()
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()}, std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()},
std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()}, std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
{1}, {1},
1e-4, 1e-4,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp"
namespace ck {
// clang-format off
// Assume:
// 1) work_buffer is buffer (typically LDS) allocated outside as workspace
// 2) work_buffer has T elements, and space size is no less than 3*BlockSize
// 3) mean_value, var_value and count is the input data in vgpr from each thread
// 4) mean_value, var_value and count is the over-written reduced output in vgpr for each thread
// 5) Merge mean and M from ThreadwiseWelford
// clang-format on
template <typename T,
index_t BlockSize,
typename ThreadClusterLengths_M_K,
typename ThreadClusterArrangeOrder,
bool GetActualVariance = true>
struct BlockwiseWelford
{
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
"The product of cluster lengths should be same as BlockSize!");
static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
__device__ static inline void
Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
{
int count = count_a + count_b;
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
T delta = mean_b - mean_a;
mean_a += delta * count_b_over_count;
var_a += var_b + delta * delta * count_a * count_b_over_count;
count_a = count;
}
__device__ static void Run(T& mean_value, T& var_value, int& count)
{
__shared__ T mean_block_buf[BlockSize];
__shared__ T var_block_buf[BlockSize];
__shared__ int count_block_buf[BlockSize];
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
mean_block_buf[offset1] = mean_value;
var_block_buf[offset1] = var_value;
count_block_buf[offset1] = count;
block_sync_lds();
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
if(thread_k_cluster_id < indOffset)
{
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));
T mean1 = mean_block_buf[offset1];
T var1 = var_block_buf[offset1];
int count1 = count_block_buf[offset1];
T mean2 = mean_block_buf[offset2];
T var2 = var_block_buf[offset2];
int count2 = count_block_buf[offset2];
Merge(mean1, var1, count1, mean2, var2, count2);
mean_block_buf[offset1] = mean1;
var_block_buf[offset1] = var1;
count_block_buf[offset1] = count1;
}
block_sync_lds();
});
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
count = count_block_buf[offset];
mean_value = mean_block_buf[offset];
if constexpr(GetActualVariance)
var_value = var_block_buf[offset] / count;
else
var_value = var_block_buf[offset];
};
};
} // namespace ck
...@@ -9,13 +9,48 @@ ...@@ -9,13 +9,48 @@
#include "ck/utility/reduction_operator.hpp" #include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseReduction,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_K>
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_K gamma_grid_desc_k,
const GridDesc_K beta_grid_desc_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_k,
beta_grid_desc_k,
y_grid_desc_m_k,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
acc_elementwise_op);
};
} // namespace ck
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -39,7 +74,7 @@ template <typename XDataType, ...@@ -39,7 +74,7 @@ template <typename XDataType,
index_t GammaSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize> index_t YDstVectorSize>
struct DeviceLayernorm : public DeviceNormalization2<XDataType, struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, AccDataType,
...@@ -58,27 +93,74 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -58,27 +93,74 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
using PassThrough = tensor_operation::element_wise::PassThrough; using PassThrough = tensor_operation::element_wise::PassThrough;
// Used for freeloading of some handy functions from DeviceReduceMultiBlock static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
using Reduction = DeviceReduceMultiBlock<XDataType, static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
AccDataType,
YDataType, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
Rank, const std::vector<index_t>& inStrides,
NumReduceDim, int blkGroupSize,
reduce::Add, int numBlockTileIteration)
PassThrough, // InElementwiseOperation {
AccElementwiseOperation, // AccElementwiseOperation constexpr index_t NumInvariantDim = Rank - NumReduceDim;
InMemoryDataOperationEnum::Set, static constexpr index_t numSrcDim = Rank;
false, // PropagateNan static constexpr bool reduceAllDim = (NumInvariantDim == 0);
false, // OutputIndex
false, // HaveIndexInputIfOutputIndex const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
BlockSize, const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
MThreadClusterSize,
KThreadClusterSize, const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
MThreadSliceSize,
KThreadSliceSize, const auto in_grid_desc_m_k = [&]() {
XYSrcVectorDim, if constexpr(reduceAllDim)
XSrcVectorSize, {
1>; // YDstVectorSize const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeAffine1dDescriptor(const std::vector<index_t>& Lengths, static auto MakeAffine1dDescriptor(const std::vector<index_t>& Lengths,
const std::vector<index_t>& Strides, const std::vector<index_t>& Strides,
...@@ -97,7 +179,7 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -97,7 +179,7 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{}); const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{});
const int reduceSizePerBlock = Reduction::K_BlockTileSize * numBlockTileIteration; const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto Pad_K = reduceSizePerBlock * blkGroupSize - reduceTotalLength; const auto Pad_K = reduceSizePerBlock * blkGroupSize - reduceTotalLength;
...@@ -110,10 +192,11 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -110,10 +192,11 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
return (grid_desc_k_padded); return (grid_desc_k_padded);
}; };
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridDesc_K = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1)); using GridDesc_K = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1));
using GridwiseReduceLayernormGeneric = GridwiseLayernorm_mk_to_mk<XDataType, using GridwiseReduceLayernormGeneric =
GridwiseLayernormWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
...@@ -134,7 +217,8 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -134,7 +217,8 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
YDstVectorSize, YDstVectorSize,
false>; false>;
using GridwiseReduceLayernormSweepOnce = GridwiseLayernorm_mk_to_mk<XDataType, using GridwiseReduceLayernormSweepOnce =
GridwiseLayernormWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
...@@ -155,12 +239,13 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -155,12 +239,13 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
YDstVectorSize, YDstVectorSize,
true>; true>;
struct Argument : public Reduction::Argument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> lengths, Argument(const std::vector<index_t> lengths,
const std::vector<index_t> xStrides, const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccElementwiseOperation acc_elementwise_op, AccElementwiseOperation acc_elementwise_op,
AccDataType epsilon, AccDataType epsilon,
...@@ -168,53 +253,76 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -168,53 +253,76 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
YDataType* p_y) YDataType* p_y)
: Reduction::Argument(lengths, : epsilon_(epsilon),
xStrides, p_x_(p_x),
{},
{},
reduceDims,
0.0f, // alpha
0.0f, // beta
p_x,
nullptr,
p_y,
nullptr,
acc_elementwise_op,
PassThrough{}),
epsilon_(epsilon),
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y),
gammaStrides_(gammaStrides), gammaStrides_(gammaStrides),
betaStrides_(betaStrides) betaStrides_(betaStrides),
acc_elementwise_op_(acc_elementwise_op)
{ {
reduceLength_.resize(NumReduceDim); Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
long_index_t invariant_total_length;
long_index_t reduce_total_length;
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(Lengths_);
blkGroupSize_ = 1;
numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize_;
reduceLengths_.resize(NumReduceDim);
for(int i = 0; i < NumReduceDim; ++i) for(int i = 0; i < NumReduceDim; ++i)
{ {
reduceLength_[i] = lengths[reduceDims[i]]; reduceLengths_[i] = lengths[reduceDims[i]];
} }
} }
AccDataType epsilon_; AccDataType epsilon_;
const XDataType* p_x_;
const GammaDataType* p_gamma_; const GammaDataType* p_gamma_;
const BetaDataType* p_beta_; const BetaDataType* p_beta_;
std::vector<index_t> reduceLength_; YDataType* p_y_;
std::vector<index_t> Lengths_;
std::vector<index_t> xStrides_;
std::vector<index_t> reduceLengths_;
std::vector<index_t> gammaStrides_; std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_; std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_;
AccElementwiseOperation acc_elementwise_op_;
int blkGroupSize_;
int numBlockTileIteration_;
size_t gridSize_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto x_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto x_grid_desc_m_k = MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.Lengths_, arg.xStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
const auto gamma_grid_desc_k = MakeAffine1dDescriptor( const auto gamma_grid_desc_k = MakeAffine1dDescriptor(arg.reduceLengths_,
arg.reduceLength_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.gammaStrides_,
const auto beta_grid_desc_k = MakeAffine1dDescriptor( arg.blkGroupSize_,
arg.reduceLength_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.numBlockTileIteration_);
const auto y_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto beta_grid_desc_k = MakeAffine1dDescriptor(arg.reduceLengths_,
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.betaStrides_,
arg.blkGroupSize_,
arg.numBlockTileIteration_);
const auto y_grid_desc_m_k = MakeSrc2dDescriptor(
arg.Lengths_, arg.yStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
bool sweep_once = bool sweep_once =
x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
...@@ -241,19 +349,19 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -241,19 +349,19 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
kernel_main, kernel_main,
dim3(arg.gridSize), dim3(arg.gridSize_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
x_grid_desc_m_k, x_grid_desc_m_k,
gamma_grid_desc_k, gamma_grid_desc_k,
beta_grid_desc_k, beta_grid_desc_k,
y_grid_desc_m_k, y_grid_desc_m_k,
arg.numBlockTileIteration, arg.numBlockTileIteration_,
arg.epsilon_, arg.epsilon_,
arg.in_dev_, arg.p_x_,
arg.p_gamma_, arg.p_gamma_,
arg.p_beta_, arg.p_beta_,
arg.out_dev_, arg.p_y_,
arg.acc_elementwise_op_); arg.acc_elementwise_op_);
return (avg_time); return (avg_time);
...@@ -270,12 +378,33 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -270,12 +378,33 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg); const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(!Reduction::IsSupportedArgument(p_arg_)) constexpr index_t NumInvariantDim = Rank - NumReduceDim;
if constexpr(XYSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return false;
}
else
{ {
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false; return false;
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
return false;
};
} }
else
{
if(p_arg_->xStrides_[Rank - 1] != 1)
return false;
if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
return false;
};
if(p_arg_->inLengths_[Rank - 1] % YDstVectorSize != 0) if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
{ {
return false; return false;
} }
...@@ -309,6 +438,7 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -309,6 +438,7 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
const std::vector<index_t> xStrides, const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon, AccDataType epsilon,
const void* p_x, const void* p_x,
...@@ -321,6 +451,7 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -321,6 +451,7 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
xStrides, xStrides,
gammaStrides, gammaStrides,
betaStrides, betaStrides,
yStrides,
reduceDims, reduceDims,
acc_elementwise_op, acc_elementwise_op,
epsilon, epsilon,
...@@ -340,7 +471,7 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType, ...@@ -340,7 +471,7 @@ struct DeviceLayernorm : public DeviceNormalization2<XDataType,
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceLayernorm<" << BlockSize << ","; str << "DeviceLayernormImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
......
...@@ -46,13 +46,14 @@ template <typename XDataType, ...@@ -46,13 +46,14 @@ template <typename XDataType,
typename AccElementwiseOperation, typename AccElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
struct DeviceNormalization2 : public BaseOperator struct DeviceLayernorm : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths, MakeArgumentPointer(const std::vector<index_t> lengths,
const std::vector<index_t> xStrides, const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon, AccDataType epsilon,
const void* p_x, const void* p_x,
...@@ -72,7 +73,7 @@ template <typename XDataType, ...@@ -72,7 +73,7 @@ template <typename XDataType,
typename AccElementwiseOperation, typename AccElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
using DeviceNormalization2Ptr = std::unique_ptr<DeviceNormalization2<XDataType, using DeviceLayernormPtr = std::unique_ptr<DeviceLayernorm<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, AccDataType,
......
...@@ -14,40 +14,6 @@ ...@@ -14,40 +14,6 @@
namespace ck { namespace ck {
template <typename GridwiseReduction,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_K>
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_K gamma_grid_desc_k,
const GridDesc_K beta_grid_desc_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_k,
beta_grid_desc_k,
y_grid_desc_m_k,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
acc_elementwise_op);
};
// Y = LayerNorm(X, Beta, Gamma) // Y = LayerNorm(X, Beta, Gamma)
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
...@@ -69,7 +35,7 @@ template <typename XDataType, ...@@ -69,7 +35,7 @@ template <typename XDataType,
index_t YDstVectorDim, index_t YDstVectorDim,
index_t YDstVectorSize, index_t YDstVectorSize,
bool SweepOnce> bool SweepOnce>
struct GridwiseLayernorm_mk_to_mk struct GridwiseLayernormNaiveVariance_mk_to_mk
{ {
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
// Y = LayerNorm(X, Beta, Gamma)
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
bool SweepOnce>
struct GridwiseLayernormWelfordVariance_mk_to_mk
{
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
int thread_k_cluster_id)
{
int kPerBlock = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
int kPerThread =
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
if(kPerBlockTail > 0)
{
int thread_max_len = (thread_k_cluster_id + 1) * KThreadSliceSize;
int delta = thread_max_len - kPerBlockTail;
delta = math::clamp(thread_max_len - kPerBlockTail, 0, KThreadSliceSize);
kPerThread += KThreadSliceSize - delta;
}
return kPerThread;
}
__device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_K& gamma_grid_desc_k,
const GridDesc_K& beta_grid_desc_k,
const GridDesc_M_K& y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{
if constexpr(SweepOnce)
{
num_k_block_tile_iteration = 1;
}
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_K = Sequence<KThreadSliceSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_k =
make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
AccDataType,
GridDesc_K,
decltype(thread_buffer_desc_k),
ThreadBufferLengths_K,
Sequence<0>,
0,
GammaSrcVectorSize,
1,
true>(
gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType,
GridDesc_K,
decltype(thread_buffer_desc_k),
ThreadBufferLengths_K,
Sequence<0>,
0,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
decltype(thread_buffer_desc_m_k),
GridDesc_M_K,
AccElementwiseOperation,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
YDstVectorDim,
YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
y_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
acc_elementwise_op);
// Copy x from Cache
// one pass: fwd, second pass: bwd
constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
constexpr auto thread_copy_fwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
});
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
auto thread_copy_tail_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
if constexpr(!SweepOnce)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
}
threadwise_gamma_load.Run(gamma_grid_desc_k,
gamma_global_val_buf,
thread_buffer_desc_k,
make_tuple(I0),
gamma_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
// normalize
y_thread_buf(Number<offset_m_k>{}) =
(x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
sqrt(var_thread_buf(iM) + epsilon);
// gamma
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
});
});
threadwise_beta_load.Run(beta_grid_desc_k,
beta_global_val_buf,
thread_buffer_desc_k,
make_tuple(I0),
beta_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
// beta
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
});
});
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf,
y_grid_desc_m_k,
y_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/math_v2.hpp"
namespace ck {
// Assume
// 1) XDesc is known at compile-time
// 2) MeanVarDesc is known at compile-time
// 3) XBuffer is static buffer
// 4) MeanBuffer is static buffer
// 5) VarBuffer is static buffer
template <typename T, typename XThreadDesc_M_K, typename MeanVarThreadDesc_M>
struct ThreadwiseWelford
{
static constexpr auto x_thread_desc_m_k = XThreadDesc_M_K{};
static constexpr auto mean_var_thread_desc_m = MeanVarThreadDesc_M{};
static constexpr auto thread_x_length_m = x_thread_desc_m_k.GetLength(Number<0>{});
static constexpr auto thread_x_length_k = x_thread_desc_m_k.GetLength(Number<1>{});
static constexpr auto thread_mean_var_length_m = mean_var_thread_desc_m.GetLength(Number<0>{});
static_assert(thread_x_length_m == thread_mean_var_length_m,
"lengths of source and mean/var buffer must match!");
__device__ constexpr ThreadwiseWelford() : cur_count_(0), max_count_(0) {}
__device__ inline void Update(T& mean, T& var, T x)
{
using ck::math::isnan;
if(isnan(x))
{
mean = x;
var = x;
}
else
{
T delta = x - mean;
mean += delta / cur_count_;
T delta2 = x - mean;
var += delta * delta2;
}
}
template <typename XBufferType, typename MeanBufferType, typename VarBufferType>
__device__ void
Run(const XBufferType& x_buf_m_k, MeanBufferType& mean_buf_m, VarBufferType& var_buf_m)
{
// FIXME - Better naming for var_buf_m
static_for<0, thread_x_length_k, 1>{}([&](auto iK) {
if(cur_count_ < max_count_)
{
++cur_count_;
static_for<0, thread_x_length_m, 1>{}([&](auto iM) {
constexpr index_t out_offset =
mean_var_thread_desc_m.CalculateOffset(make_tuple(iM));
constexpr auto in_offset =
x_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
Update(mean_buf_m(Number<out_offset>{}),
var_buf_m(Number<out_offset>{}),
x_buf_m_k[Number<in_offset>{}]);
});
}
});
};
int cur_count_;
int max_count_;
};
} // namespace ck
...@@ -144,6 +144,12 @@ __host__ __device__ constexpr auto min(X x, Ys... ys) ...@@ -144,6 +144,12 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
return min(x, min(ys...)); return min(x, min(ys...));
} }
template <typename T>
__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
return min(max(x, lowerbound), upperbound);
}
// disallow implicit type casting // disallow implicit type casting
template <typename T> template <typename T>
__device__ T exp(T x); __device__ T exp(T x);
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...@@ -21,28 +21,28 @@ template <index_t Rank, index_t Reduce> ...@@ -21,28 +21,28 @@ template <index_t Rank, index_t Reduce>
using device_layernorm_f16_instances = std::tuple< using device_layernorm_f16_instances = std::tuple<
// clang-format off // clang-format off
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1>, // fallback kernel DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1>, // fallback kernel
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 2, 2, 2>, // fallback kernel DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 2, 2, 2>, // fallback kernel
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4, 4, 4>, // fallback kernel DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4, 4, 4>, // fallback kernel
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8, 8, 8>, DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8, 8, 8>,
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8, 8, 8>, DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8, 8, 8>,
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 8, 8, 8>, DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 8, 8, 8>,
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 8, 8, 8>, DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 8, 8, 8>,
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 8, 8, 8>, DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 8, 8, 8>,
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 8, 8, 8>, DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 8, 8, 8>,
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 8, 8, 8>, DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 8, 8, 8>,
DeviceLayernorm<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 8, 8, 8> DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 8, 8, 8>
// clang-format on // clang-format on
>; >;
void add_device_layernorm_f16_rank2_instances( void add_device_layernorm_f16_rank2_instances(
std::vector<DeviceNormalization2Ptr<F16, F16, F16, F32, F16, Pass, 2, 1>>& instances) std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, Pass, 2, 1>>& instances)
{ {
add_device_operation_instances(instances, device_layernorm_f16_instances<2, 1>{}); add_device_operation_instances(instances, device_layernorm_f16_instances<2, 1>{});
} }
void add_device_layernorm_f16_rank4_instances( void add_device_layernorm_f16_rank4_instances(
std::vector<DeviceNormalization2Ptr<F16, F16, F16, F32, F16, Pass, 4, 3>>& instances) std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, Pass, 4, 3>>& instances)
{ {
add_device_operation_instances(instances, device_layernorm_f16_instances<4, 3>{}); add_device_operation_instances(instances, device_layernorm_f16_instances<4, 3>{});
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...@@ -20,27 +20,27 @@ template <index_t Rank, index_t Reduce> ...@@ -20,27 +20,27 @@ template <index_t Rank, index_t Reduce>
using device_layernorm_f32_instances = std::tuple< using device_layernorm_f32_instances = std::tuple<
// clang-format off // clang-format off
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceLayernorm<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1>, // fallback kernel DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1>, // fallback kernel
DeviceLayernorm<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 2, 2, 2>, // fallback kernel DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 2, 2, 2>, // fallback kernel
DeviceLayernorm<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4, 4, 4>, DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4, 4, 4>,
DeviceLayernorm<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 4, 4, 4>, DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 4, 4, 4>,
DeviceLayernorm<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 4, 4, 4>, DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 4, 4, 4>,
DeviceLayernorm<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 4, 4, 4>, DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 4, 4, 4>,
DeviceLayernorm<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 4, 4, 4>, DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 4, 4, 4>,
DeviceLayernorm<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 4, 4, 4>, DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 4, 4, 4>,
DeviceLayernorm<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 4, 4, 4>, DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 4, 4, 4>,
DeviceLayernorm<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 4, 4, 4> DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 4, 4, 4>
// clang-format on // clang-format on
>; >;
void add_device_layernorm_f32_rank2_instances( void add_device_layernorm_f32_rank2_instances(
std::vector<DeviceNormalization2Ptr<F32, F32, F32, F32, F32, Pass, 2, 1>>& instances) std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, Pass, 2, 1>>& instances)
{ {
add_device_operation_instances(instances, device_layernorm_f32_instances<2, 1>{}); add_device_operation_instances(instances, device_layernorm_f32_instances<2, 1>{});
} }
void add_device_layernorm_f32_rank4_instances( void add_device_layernorm_f32_rank4_instances(
std::vector<DeviceNormalization2Ptr<F32, F32, F32, F32, F32, Pass, 4, 3>>& instances) std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, Pass, 4, 3>>& instances)
{ {
add_device_operation_instances(instances, device_layernorm_f32_instances<4, 3>{}); add_device_operation_instances(instances, device_layernorm_f32_instances<4, 3>{});
} }
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "profiler/include/data_type_enum.hpp" #include "profiler/include/data_type_enum.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -25,10 +25,10 @@ using F32 = float; ...@@ -25,10 +25,10 @@ using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_device_layernorm_f16_rank2_instances( void add_device_layernorm_f16_rank2_instances(
std::vector<DeviceNormalization2Ptr<F16, F16, F16, F32, F16, PassThrough, 2, 1>>&); std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, PassThrough, 2, 1>>&);
void add_device_layernorm_f32_rank2_instances( void add_device_layernorm_f32_rank2_instances(
std::vector<DeviceNormalization2Ptr<F32, F32, F32, F32, F32, PassThrough, 2, 1>>&); std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, PassThrough, 2, 1>>&);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
...@@ -105,7 +105,7 @@ void profile_layernorm_impl(int do_verification, ...@@ -105,7 +105,7 @@ void profile_layernorm_impl(int do_verification,
// add device normalization instances // add device normalization instances
constexpr int NumReduceDim = Rank - 1; constexpr int NumReduceDim = Rank - 1;
std::vector<tensor_operation::device::DeviceNormalization2Ptr<XDataType, std::vector<tensor_operation::device::DeviceLayernormPtr<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, AccDataType,
...@@ -163,6 +163,7 @@ void profile_layernorm_impl(int do_verification, ...@@ -163,6 +163,7 @@ void profile_layernorm_impl(int do_verification,
strideXY, strideXY,
strideGamma, strideGamma,
strideBeta, strideBeta,
strideXY,
reduce_dim, reduce_dim,
1e-4, 1e-4,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/utility/number.hpp" #include "ck/utility/number.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
...@@ -63,7 +63,7 @@ class TestLayernorm : public ::testing::Test ...@@ -63,7 +63,7 @@ class TestLayernorm : public ::testing::Test
Rank, Rank,
NumReduceDim>; NumReduceDim>;
using DeviceInstance = tensor_operation::device::DeviceLayernorm<XDataType, using DeviceInstance = tensor_operation::device::DeviceLayernormImpl<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, AccDataType,
...@@ -119,6 +119,7 @@ class TestLayernorm : public ::testing::Test ...@@ -119,6 +119,7 @@ class TestLayernorm : public ::testing::Test
gamma.mDesc.GetStrides().end()}, gamma.mDesc.GetStrides().end()},
std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(),
beta.mDesc.GetStrides().end()}, beta.mDesc.GetStrides().end()},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
reduceDims, reduceDims,
1e-4, 1e-4,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment