Commit bd34d666 authored by rocking

Add 5ary elementwise for normalization

parent 980ed33a
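
This commit generalizes the former normalization device op into a generic 5-input/1-output elementwise op; the normalization math itself is expected to be supplied as the ElementwiseFunctor. Below is a minimal sketch of such a functor matching the call convention functor(out, a, b, c, d, e) used by the gridwise kernel; the name NormalizeFunctor, the float types, and the epsilon member are illustrative and not part of this commit.

#include <cmath> // sqrtf on the host path

// Illustrative only: maps a -> x, b -> E[x], c -> E[x^2], d -> gamma, e -> beta
// and computes out = ((x - E[x]) / sqrt(E[x^2] - E[x]^2 + epsilon)) * gamma + beta
struct NormalizeFunctor
{
    float epsilon_ = 1e-5f;

    __host__ __device__ void operator()(float& out,
                                        const float& x,
                                        const float& mean,
                                        const float& mean_square,
                                        const float& gamma,
                                        const float& beta) const
    {
        const float variance = mean_square - mean * mean; // E[x^2] - E[x]^2
        out = (x - mean) * gamma / sqrtf(variance + epsilon_) + beta;
    }
};
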
......@@ -4,6 +4,7 @@
#include "device.hpp"
#include "device_base.hpp"
#include "common_header.hpp"
#include "gridwise_5ary_Elementwise_1d.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
......@@ -12,24 +13,23 @@ namespace ck {
namespace tensor_operation {
namespace device {
// Input: x, E[x], E[x^2], Gamma, Beta
// Output: {[(x - E[x]) / sqrt(E[x^2] - E[x]^2 + epsilon)] * gamma} + beta
template <typename XDataType,
typename MeanDataType,
typename MeanSquareDataType,
typename GammaDataType,
typename BetaDataType,
typename OutDataType,
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename ComputeDataType,
typename OutElementwiseFunctor,
typename ElementwiseFunctor,
index_t NDim,
index_t MPerThread,
index_t XScalarPerVector,
index_t MeanScalarPerVector,
index_t MeanSquareScalarPerVector,
index_t GammaScalarPerVector,
index_t BetaScalarPerVector>
struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector,
index_t DScalarPerVector,
index_t EScalarPerVector,
index_t FScalarPerVector>
struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
{
static constexpr auto I0 = Number<0>{};
......@@ -73,68 +73,96 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
return PadDescriptor_M_1d(desc, gridSize, blockSize);
}
using GridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using DGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using EGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using FGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using Gridwise5AryEltwise = Gridwise5AryElementwise_1D<ADataType,
BDataType,
CDataType,
DDataType,
EDataType,
FDataType,
ComputeDataType,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
DGridDesc_M,
EGridDesc_M,
FGridDesc_M,
ElementwiseFunctor,
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector,
DScalarPerVector,
EScalarPerVector,
FScalarPerVector>;
struct Argument : public BaseArgument
{
Argument(const XDataType* p_x,
const MeanDataType* p_mean,
const MeanSquareDataType* p_mean_square,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
OutDataType* p_output,
Argument(const ADataType* p_a,
const BDataType* p_b,
const CDataType* p_c,
const DDataType* p_d,
const EDataType* p_e,
FDataType* p_f,
const std::vector<index_t>& lengths,
const std::vector<index_t>& stride_x,
const std::vector<index_t>& stride_mean,
const std::vector<index_t>& stride_mean_square,
const std::vector<index_t>& stride_gamma,
const std::vector<index_t>& stride_beta,
const std::vector<index_t>& stride_output,
OutElementwiseFunctor functor)
: p_x_(p_x),
p_mean_(p_mean),
p_mean_square_(p_mean_square),
p_gamma_(p_gamma),
p_beta_(p_beta),
p_output_(p_output),
const std::vector<index_t>& a_strides,
const std::vector<index_t>& b_strides,
const std::vector<index_t>& c_strides,
const std::vector<index_t>& d_strides,
const std::vector<index_t>& e_strides,
const std::vector<index_t>& f_strides,
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
p_d_(p_d),
p_e_(p_e),
p_f_(p_f),
lengths_(lengths),
stride_x_(stride_x),
stride_mean_(stride_mean),
stride_mean_square_(stride_mean_square),
stride_gamma_(stride_gamma),
stride_beta_(stride_beta),
a_strides_(a_strides),
b_strides_(b_strides),
c_strides_(c_strides),
d_strides_(d_strides),
e_strides_(e_strides),
f_strides_(f_strides),
functor_(functor),
blockSize_(256),
gridSize_(120) // FIXME - calculate the grid size from the number of CUs in the future
{
x_grid_desc_m_ = MakeDescriptor_M(lengths, stride_x, gridSize_, blockSize_);
mean_grid_desc_m_ = MakeDescriptor_M(lengths, stride_mean, gridSize_, blockSize_);
mean_square_grid_desc_m_ =
MakeDescriptor_M(lengths, stride_mean_square, gridSize_, blockSize_);
gamma_grid_desc_m_ = MakeDescriptor_M(lengths, stride_gamma, gridSize_, blockSize_);
beta_grid_desc_m_ = MakeDescriptor_M(lengths, stride_beta, gridSize_, blockSize_);
output_grid_desc_m_ = MakeDescriptor_M(lengths, stride_output, gridSize_, blockSize_);
a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
d_grid_desc_m_ = MakeDescriptor_M(lengths, d_strides, gridSize_, blockSize_);
e_grid_desc_m_ = MakeDescriptor_M(lengths, e_strides, gridSize_, blockSize_);
f_grid_desc_m_ = MakeDescriptor_M(lengths, f_strides, gridSize_, blockSize_);
}
const XDataType* p_x_;
const MeanDataType* p_mean_;
const MeanSquareDataType* p_mean_square_;
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
OutDataType* p_output_;
const ADataType* p_a_;
const BDataType* p_b_;
const CDataType* p_c_;
const DDataType* p_d_;
const EDataType* p_e_;
FDataType* p_f_;
std::vector<index_t> lengths_;
GridDesc_M x_grid_desc_m_;
GridDesc_M mean_grid_desc_m_;
GridDesc_M mean_square_grid_desc_m_;
GridDesc_M gamma_grid_desc_m_;
GridDesc_M beta_grid_desc_m_;
GridDesc_M output_grid_desc_m_;
std::vector<index_t> stride_x_;
std::vector<index_t> stride_mean_;
std::vector<index_t> stride_mean_square_;
std::vector<index_t> stride_gamma_;
std::vector<index_t> stride_beta_;
OutElementwiseFunctor functor_;
AGridDesc_M a_grid_desc_m_;
BGridDesc_M b_grid_desc_m_;
CGridDesc_M c_grid_desc_m_;
DGridDesc_M d_grid_desc_m_;
EGridDesc_M e_grid_desc_m_;
FGridDesc_M f_grid_desc_m_;
std::vector<index_t> a_strides_;
std::vector<index_t> b_strides_;
std::vector<index_t> c_strides_;
std::vector<index_t> d_strides_;
std::vector<index_t> e_strides_;
std::vector<index_t> f_strides_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
};
......@@ -143,10 +171,37 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
// TODO
(void)arg;
(void)stream_config;
return 0;
const auto kernel = kernel_5ary_elementwise_1d<Gridwise5AryEltwise,
ADataType,
BDataType,
CDataType,
DDataType,
EDataType,
FDataType,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
DGridDesc_M,
EGridDesc_M,
FGridDesc_M,
ElementwiseFunctor>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.p_d_,
arg.p_e_,
arg.p_f_,
arg.a_grid_desc_m_,
arg.b_grid_desc_m_,
arg.c_grid_desc_m_,
arg.d_grid_desc_m_,
arg.e_grid_desc_m_,
arg.f_grid_desc_m_,
arg.functor_);
return elapsed_time;
}
// polymorphic
......@@ -181,20 +236,22 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
return ret;
};
if(!IsScalarPerVectorValid(pArg->stride_x_.back() == 1, XScalarPerVector))
if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->stride_mean_.back() == 1, MeanScalarPerVector))
if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->stride_mean_square_.back() == 1,
MeanSquareScalarPerVector))
if(!IsScalarPerVectorValid(pArg->d_strides_.back() == 1, DScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->stride_gamma_.back() == 1, GammaScalarPerVector))
if(!IsScalarPerVectorValid(pArg->e_strides_.back() == 1, EScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->stride_beta_.back() == 1, BetaScalarPerVector))
if(!IsScalarPerVectorValid(pArg->f_strides_.back() == 1, FScalarPerVector))
return false;
};
};
......
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename Gridwise5AryEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename DGridDesc_M,
typename EGridDesc_M,
typename FGridDesc_M,
typename ElementwiseFunctor>
__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
const CDataType* __restrict__ p_c_global,
const DDataType* __restrict__ p_d_global,
const EDataType* __restrict__ p_e_global,
FDataType* __restrict__ p_f_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const DGridDesc_M d_grid_desc_m,
const EGridDesc_M e_grid_desc_m,
const FGridDesc_M f_grid_desc_m,
const ElementwiseFunctor functor)
{
Gridwise5AryEltwise::Run(p_a_global,
p_b_global,
p_c_global,
p_d_global,
p_e_global,
p_f_global,
a_grid_desc_m,
b_grid_desc_m,
c_grid_desc_m,
d_grid_desc_m,
e_grid_desc_m,
f_grid_desc_m,
functor);
}
// TODO - implement n-ary Elementwise_1D with a tuple of inputs and a tuple of outputs
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename ComputeDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename DGridDesc_M,
typename EGridDesc_M,
typename FGridDesc_M,
typename ElementwiseFunctor,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector,
index_t DScalarPerVector,
index_t EScalarPerVector,
index_t FScalarPerVector>
struct Gridwise5AryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * MPerThread);
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
const CDataType* __restrict__ p_c_global,
const DDataType* __restrict__ p_d_global,
const EDataType* __restrict__ p_e_global,
FDataType* __restrict__ p_f_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const DGridDesc_M d_grid_desc_m,
const EGridDesc_M e_grid_desc_m,
const FGridDesc_M f_grid_desc_m,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m.GetElementSpaceSize());
const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m.GetElementSpaceSize());
const auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_grid_desc_m.GetElementSpaceSize());
const auto e_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_global, e_grid_desc_m.GetElementSpaceSize());
auto f_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_f_global, f_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> d_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> e_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> f_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ComputeDataType,
AGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
AScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{a_grid_desc_m, thread_store_global_offset};
auto b_global_load =
ThreadwiseTensorSliceTransfer_v2<BDataType,
ComputeDataType,
BGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
BScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{b_grid_desc_m, thread_store_global_offset};
auto c_global_load =
ThreadwiseTensorSliceTransfer_v2<CDataType,
ComputeDataType,
CGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
CScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{c_grid_desc_m, thread_store_global_offset};
auto d_global_load =
ThreadwiseTensorSliceTransfer_v2<DDataType,
ComputeDataType,
DGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
DScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{d_grid_desc_m, thread_store_global_offset};
auto e_global_load =
ThreadwiseTensorSliceTransfer_v2<EDataType,
ComputeDataType,
EGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
EScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{e_grid_desc_m, thread_store_global_offset};
auto f_global_write =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
FDataType,
decltype(thread_desc_m),
FGridDesc_M,
PassThrough,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
FScalarPerVector, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
f_grid_desc_m, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = c_grid_desc_m.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = M / (loop_step);
do
{
// read and process MPerThread elements
a_global_load.Run(
a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
c_global_load.Run(
c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf);
d_global_load.Run(
d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf);
e_global_load.Run(
e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf);
static_for<0, MPerThread, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
functor(f_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}),
c_thread_buf(Number<offset>{}),
d_thread_buf(Number<offset>{}),
e_thread_buf(Number<offset>{}));
});
f_global_write.Run(thread_desc_m,
make_tuple(I0), // SrcSliceOriginIdx
f_thread_buf,
f_grid_desc_m,
f_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index);
d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index);
e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index);
f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index);
} while(--num_iter);
}
};
} // namespace ck
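
For reference, a rough host-side usage sketch. It assumes the device op exposes the usual CK-style Invoker whose Run() is shown above, that NormalizeFunctor is the sketch from the commit description, and that p_a .. p_f, lengths, and the *_strides vectors are placeholders for already-allocated device buffers and shape metadata.

// Sketch only: template arguments and buffers are placeholders, not tuned values.
using DeviceOp = ck::tensor_operation::device::Device5AryElementwise_Xdl_CShuffle<
    float, float, float, float, float, float, // ADataType .. FDataType
    float,                                     // ComputeDataType
    NormalizeFunctor,                          // ElementwiseFunctor
    2,                                         // NDim
    8,                                         // MPerThread
    8, 8, 8, 8, 8, 8>;                         // A .. F ScalarPerVector

DeviceOp::Argument arg(p_a, p_b, p_c, p_d, p_e, p_f,
                       lengths,
                       a_strides, b_strides, c_strides,
                       d_strides, e_strides, f_strides,
                       NormalizeFunctor{});

DeviceOp::Invoker invoker;
float elapsed = invoker.Run(arg, StreamConfig{}); // measured kernel time
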