"model/models/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "4ff8a691bcef296aa976e19d0ba9c7b74ae9f27c"
Commit 8b7aeb35 authored by rocking

Implement layernorm kernel and deviceOp
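
For reference, the operation implemented here is standard layer normalization along the reduced (K) dimension, matching the comments in the gridwise kernel below: for each row x,

    y = (x - E[x]) / sqrt(var(x) + epsilon) * gamma + beta,    with var(x) = E[x^2] - E[x]^2,

where gamma and beta are the affine parameters (one value per element of the reduced dimension, broadcast across rows) and epsilon is the small constant passed to MakeArgumentPointer (1e-4 in the example).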

parent 12235112
@@ -209,6 +209,8 @@ int main(int argc, char* argv[])
     auto device_instance = DeviceInstance{};
+    std::cout << i_inLengths.size() << ", " << i_inStrides.size() << std::endl;
     auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
                                                             i_inStrides,
                                                             reduceDims,
...
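# New file (directory inferred from the add_subdirectory() entry below): example/24_layernorm/CMakeLists.txt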
add_example_executable(example_layernorm_blockwise layernorm_blockwise.cpp)
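// New file (path inferred from the add_example_executable() entry above): example/24_layernorm/layernorm_blockwise.cpp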
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_common_util.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using AccDataType = float;
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType,
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // SrcVecDim (0=M, 1=K)
8, // SrcScalarPerVector
1, // AffineVecDim (0=M, 1=K)
1, // AffineScalarPerVector
8>; // OutScalarPerVector
int main()
{
bool time_kernel = false;
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t Stride = 1024;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
};
Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpace());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpace());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpace());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpace());
x_dev.ToDevice(x.mData.data());
auto device_instance = DeviceInstance{};
auto argument_ptr = device_instance.MakeArgumentPointer({M, N},
{Stride, 1},
{0, 1},
{1},
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer());
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported" << std::endl;
return 1;
};
auto invoker_ptr = device_instance.MakeInvokerPointer();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
bool pass = true;
return (pass ? 0 : 1);
}
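
Note that the example above reports pass = true unconditionally; no host verification is wired up in this commit. A minimal host reference to diff the device output against could look like the sketch below. It is a hypothetical helper (host_layernorm is not part of CK) and operates on plain float buffers, so the fp16 tensors from the example would be converted before comparison.

// Hypothetical host reference for checking the device output (not part of this commit).
#include <cmath>
#include <cstddef>
#include <vector>
void host_layernorm(const std::vector<float>& x,     // M x N, row-major
                    const std::vector<float>& gamma, // N
                    const std::vector<float>& beta,  // N
                    std::vector<float>& y,           // M x N, row-major
                    std::size_t M,
                    std::size_t N,
                    float epsilon)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        // E[x] and E[x^2] over the reduced dimension
        float sum = 0.f, sum_sq = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            const float v = x[m * N + n];
            sum += v;
            sum_sq += v * v;
        }
        const float mean = sum / N;
        const float var  = sum_sq / N - mean * mean; // var(x) = E[x^2] - E[x]^2
        // normalize, then apply the gamma/beta affine transform
        for(std::size_t n = 0; n < N; ++n)
        {
            const float norm = (x[m * N + n] - mean) / std::sqrt(var + epsilon);
            y[m * N + n]     = norm * gamma[n] + beta[n];
        }
    }
}

Copying the device result back with y_dev.FromDevice(y.mData.data()) and comparing against the host reference with ck::utils::check_err (its header is already included above) would then drive the pass flag.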
@@ -42,3 +42,4 @@ add_subdirectory(20_convnd_bwd_weight_xdl)
 add_subdirectory(21_gemm_layernorm)
 add_subdirectory(22_cgemm)
 add_subdirectory(23_softmax)
+add_subdirectory(24_layernorm)
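// New file: ck/tensor_operation/gpu/device/device_layernorm.hpp (path taken from the #include in the example above)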
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename AccDataType,
typename YDataType,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t AffineSrcVectorDim,
index_t AffineSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceLayernorm : public BaseOperator
{
static_assert(
((AffineSrcVectorDim == 0 && MThreadSliceSize % AffineSrcVectorSize == 0) ||
(AffineSrcVectorDim == 1 && KThreadSliceSize % AffineSrcVectorSize == 0)),
"Invalid thread slice sizes and/or affine vector sizes configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough;
// Reuse some handy helper functions from DeviceReduceMultiBlock (argument handling and the 2D grid descriptor)
using Reduction = DeviceReduceMultiBlock<XDataType,
AccDataType,
YDataType,
Rank,
NumReduceDim,
reduce::Add,
PassThrough, // InElementwiseOperation
PassThrough, // AccElementwiseOperation
InMemoryDataOperationEnum::Set,
false, // PropagateNan
false, // OutputIndex
false, // HaveIndexInputIfOutputIndex
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
1>; // OutDstVectorSize
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseReduce = GridwiseLayernorm_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
AffineSrcVectorDim,
AffineSrcVectorSize,
OutDstVectorSize,
false>;
struct Argument : public Reduction::Argument
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> affineStrides,
const std::vector<index_t> reduceDims,
AccDataType epsilon,
const XDataType* p_x,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y)
: Reduction::Argument(inLengths,
inStrides,
{},
{},
reduceDims,
0.0f, // alpha
0.0f, // beta
p_x,
nullptr,
p_y,
nullptr,
PassThrough{},
PassThrough{}),
epsilon_(epsilon),
p_gamma_(p_gamma),
p_beta_(p_beta)
{
affineStrides_ =
shuffle_tensor_dimensions<Rank, NumReduceDim>(affineStrides, reduceDims);
}
AccDataType epsilon_;
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
std::vector<index_t> affineStrides_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto gamma_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.affineStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto beta_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.affineStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto kernel_main = kernel_layernorm<GridwiseReduce,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
GridDesc_M_K>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
out_grid_desc_m_k,
arg.blkGroupSize,
arg.numBlockTileIteration,
arg.epsilon_,
arg.in_dev_,
arg.p_gamma_,
arg.p_beta_,
arg.out_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(!Reduction::IsSupportedArgument(p_arg_))
{
return false;
}
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
{
return false;
}
// TODO - Check AffineSrcVectorDim and AffineSrcVectorSize
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> affineStrides,
const std::vector<int> reduceDims,
AccDataType epsilon,
const void* p_x,
const void* p_gamma,
const void* p_beta,
void* p_y)
{
return std::make_unique<Argument>(inLengths,
inStrides,
affineStrides,
reduceDims,
epsilon,
static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceLayernorm<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
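// New file: ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp (path taken from the #include in device_layernorm.hpp above)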
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename GridDesc_M_K>
__global__ void kernel_layernorm(const GridDesc_M_K in_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K out_grid_desc_m_k,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global)
{
GridwiseReduction::Run(in_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
out_grid_desc_m_k,
block_group_size,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global);
};
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename GridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t AffineSrcVectorDim,
index_t AffineSrcVectorSize,
index_t OutDstVectorSize,
bool SweepOnce>
struct GridwiseLayernorm_mk_to_mk
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(KThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseSumReduce =
PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
using ThreadwiseSumReduce =
ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& out_grid_desc_m_k,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global)
{
if constexpr(SweepOnce)
{
num_k_block_tile_iteration = 1;
}
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, out_grid_desc_m_k.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
beta_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
out_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>& in_square_thread_buf = out_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
mean_square_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_value_buf =
mean_square_thread_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
true>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_gamma_load = ThreadwiseTensorSliceTransfer_v2<GammaDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
AffineSrcVectorDim,
AffineSrcVectorSize,
1,
true>(
gamma_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
AffineSrcVectorDim,
AffineSrcVectorSize,
1,
true>(
beta_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_y_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
decltype(thread_buffer_desc),
GridDesc_M_K,
PassThroughOp,
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
constexpr auto in_thread_copy_fwd_step =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto in_thread_copy_bwd_step =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, in_grid_desc_m_k.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
// First sweep over K: accumulate E(x) and E[x^2] per row, then derive var(x)
int reduce_length = in_grid_desc_m_k.GetLength(I1);
index_t reducedTiles = 0;
do
{
threadwise_x_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_square_thread_buf(Number<offset>{}) =
in_thread_buf(Number<offset>{}) * in_thread_buf(Number<offset>{});
});
});
ThreadwiseSumReduce::Reduce(in_thread_buf, mean_thread_buf);
ThreadwiseSumReduce::Reduce(in_square_thread_buf, mean_square_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
++reducedTiles;
} while(reducedTiles < num_k_block_tile_iteration);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
// var(x) = E[x^2] - E[x]^2
var_value_buf(I) =
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
});
// Second sweep over K: y = (x - E[x]) / sqrt(var[x] + epsilon) * gamma + beta
auto thread_copy_tail = (num_k_block_tile_iteration - 1) * in_thread_copy_fwd_step;
threadwise_x_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
threadwise_gamma_load.MoveSrcSliceWindow(in_grid_desc_m_k, thread_copy_tail);
threadwise_beta_load.MoveSrcSliceWindow(in_grid_desc_m_k, thread_copy_tail);
threadwise_y_store.MoveDstSliceWindow(out_grid_desc_m_k, thread_copy_tail);
reducedTiles = 0;
do
{
if constexpr(!SweepOnce)
{
threadwise_x_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
}
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
gamma_thread_buf);
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
beta_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// normalize
out_thread_buf(Number<offset>{}) =
(in_thread_buf(Number<offset>{}) - mean_thread_buf(iM)) /
sqrt(var_value_buf(iM) + epsilon);
// affine
out_thread_buf(Number<offset>{}) =
out_thread_buf(Number<offset>{}) * gamma_thread_buf(Number<offset>{}) +
beta_thread_buf(Number<offset>{});
});
});
threadwise_y_store.Run(thread_buffer_desc,
make_tuple(I0, I0),
out_thread_buf,
out_grid_desc_m_k,
out_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
threadwise_gamma_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
threadwise_beta_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
threadwise_y_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_bwd_step);
++reducedTiles;
} while(reducedTiles < num_k_block_tile_iteration);
}
};
} // namespace ck