Commit 9d2280d6 authored by rocking

1. Rename AccDataType in normalization to ComputeDataType

2. Rename AccElementwiseOperation to YElementwiseOperation in normalization
parent 1a38e362
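The rename is purely mechanical: the float type used for intermediate math keeps its role, as does the element-wise operation applied to the output Y. A minimal sketch of the two renames (illustrative aliases, not lines from the diff; in the sources below these mostly appear as template parameters):

```cpp
// Old spelling: named after "accumulation" only.
using AccDataType = float; // precision of intermediate math
using AccElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;

// New spelling after this commit: the compute precision, and the element-wise
// operation applied to the output tensor Y.
using ComputeDataType = float;
using YElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;
```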
@@ -16,7 +16,7 @@ using XDataType = ck::half_t;
 using GammaDataType = ck::half_t;
 using BetaDataType = ck::half_t;
 using YDataType = ck::half_t;
-using AccDataType = float;
+using ComputeDataType = float;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 constexpr int Rank = 2;
@@ -54,7 +54,7 @@ int main(int argc, char* argv[])
 using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
     GammaDataType,
     BetaDataType,
-    AccDataType,
+    ComputeDataType,
     YDataType,
     PassThrough,
     Rank,
...
@@ -24,7 +24,7 @@ using XDataType = ck::half_t;
 using GammaDataType = ck::half_t;
 using BetaDataType = ck::half_t;
 using YDataType = ck::half_t;
-using AccDataType = float;
+using ComputeDataType = float;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 constexpr int Rank = 2;
@@ -34,7 +34,7 @@ using DeviceInstance =
     ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
         GammaDataType,
         BetaDataType,
-        AccDataType,
+        ComputeDataType,
         YDataType,
         PassThrough,
         Rank,
@@ -121,7 +121,7 @@ int main()
         GammaDataType,
         BetaDataType,
         YDataType,
-        AccDataType,
+        ComputeDataType,
         PassThrough,
         Rank,
         NumReduceDim>;
...
@@ -27,7 +27,7 @@ using XDataType = ck::half_t;
 using GammaDataType = ck::half_t;
 using BetaDataType = ck::half_t;
 using YDataType = ck::half_t;
-using AccDataType = float;
+using ComputeDataType = float;
 struct YElementOp
 {
@@ -50,7 +50,7 @@ using DeviceInstance =
     ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
         GammaDataType,
         BetaDataType,
-        AccDataType,
+        ComputeDataType,
         YDataType,
         YElementOp,
         Rank,
@@ -157,7 +157,7 @@ int main(int argc, char* argv[])
         GammaDataType,
         BetaDataType,
         YDataType,
-        AccDataType,
+        ComputeDataType,
         YElementOp>;
     ReferenceInstance ref;
...
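The body of YElementOp used in this example is collapsed above. For readers unfamiliar with the convention, a hypothetical sketch of such a Y element-wise functor follows; the operator() shape assumes CK's usual element-wise functor style (as with PassThrough), and the body here is illustrative only:

```cpp
// Hypothetical YElementwiseOperation: a functor applied to each normalized
// value right before it is stored to Y. PassThrough is the identity (y = x).
struct YElementOp
{
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        y = x; // replace with e.g. an activation to fuse it into the store
    }
};
```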
@@ -14,9 +14,9 @@ namespace device {
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
-          typename AccDataType,
+          typename ComputeDataType,
           typename YDataType,
-          typename AccElementwiseOperation,
+          typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
 struct DeviceNormalization : public BaseOperator
@@ -35,7 +35,7 @@ struct DeviceNormalization : public BaseOperator
         void* p_y,
         void* p_savedMean,
         void* p_savedInvVar,
-        AccElementwiseOperation acc_elementwise_op) = 0;
+        YElementwiseOperation y_elementwise_op) = 0;
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
@@ -43,17 +43,17 @@ struct DeviceNormalization : public BaseOperator
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
-          typename AccDataType,
+          typename ComputeDataType,
           typename YDataType,
-          typename AccElementwiseOperation,
+          typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
 using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
     GammaDataType,
     BetaDataType,
-    AccDataType,
+    ComputeDataType,
     YDataType,
-    AccElementwiseOperation,
+    YElementwiseOperation,
     Rank,
     NumReduceDim>>;
...
@@ -21,20 +21,20 @@ template <typename GridwiseReduction,
           typename GammaDataType,
           typename BetaDataType,
           typename YDataType,
-          typename AccDataType,
-          typename AccElementwiseOperation,
+          typename ComputeDataType,
+          typename YElementwiseOperation,
           typename GridDesc_M_K>
 __global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
                                      const GridDesc_M_K gamma_grid_desc_m_k,
                                      const GridDesc_M_K beta_grid_desc_m_k,
                                      const GridDesc_M_K y_grid_desc_m_k,
                                      index_t num_k_block_tile_iteration,
-                                     AccDataType epsilon,
+                                     ComputeDataType epsilon,
                                      const XDataType* const __restrict__ p_x_global,
                                      const GammaDataType* const __restrict__ p_gamma_global,
                                      const BetaDataType* const __restrict__ p_beta_global,
                                      YDataType* const __restrict__ p_y_global,
-                                     const AccElementwiseOperation acc_elementwise_op)
+                                     const YElementwiseOperation y_elementwise_op)
 {
     GridwiseReduction::Run(x_grid_desc_m_k,
                            gamma_grid_desc_m_k,
@@ -46,7 +46,7 @@ __global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
                            p_gamma_global,
                            p_beta_global,
                            p_y_global,
-                           acc_elementwise_op);
+                           y_elementwise_op);
 };
 } // namespace ck
@@ -58,9 +58,9 @@ namespace device {
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
-          typename AccDataType,
+          typename ComputeDataType,
           typename YDataType,
-          typename AccElementwiseOperation,
+          typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim,
           index_t BlockSize,
@@ -78,9 +78,9 @@ template <typename XDataType,
 struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                             GammaDataType,
                                                             BetaDataType,
-                                                            AccDataType,
+                                                            ComputeDataType,
                                                             YDataType,
-                                                            AccElementwiseOperation,
+                                                            YElementwiseOperation,
                                                             Rank,
                                                             NumReduceDim>
 {
@@ -172,8 +172,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             GammaDataType,
             BetaDataType,
             YDataType,
-            AccDataType,
-            AccElementwiseOperation,
+            ComputeDataType,
+            YElementwiseOperation,
             GridDesc_M_K,
             BlockSize,
             MThreadClusterSize,
@@ -194,8 +194,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             GammaDataType,
            BetaDataType,
             YDataType,
-            AccDataType,
-            AccElementwiseOperation,
+            ComputeDataType,
+            YElementwiseOperation,
             GridDesc_M_K,
             BlockSize,
             MThreadClusterSize,
@@ -220,7 +220,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                  const std::vector<index_t> betaStrides,
                  const std::vector<index_t> yStrides,
                  const std::vector<index_t> reduceDims,
-                 AccElementwiseOperation acc_elementwise_op,
+                 YElementwiseOperation y_elementwise_op,
                  double epsilon,
                  const XDataType* p_x,
                  const GammaDataType* p_gamma,
@@ -230,9 +230,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
-              acc_elementwise_op_(acc_elementwise_op)
+              y_elementwise_op_(y_elementwise_op)
         {
-            epsilon_ = static_cast<AccDataType>(epsilon);
+            epsilon_ = static_cast<ComputeDataType>(epsilon);
             Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
             xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
@@ -265,7 +265,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
         }
-        AccDataType epsilon_;
+        ComputeDataType epsilon_;
         const XDataType* p_x_;
         const GammaDataType* p_gamma_;
@@ -278,7 +278,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         std::vector<index_t> betaStrides_;
         std::vector<index_t> yStrides_;
-        AccElementwiseOperation acc_elementwise_op_;
+        YElementwiseOperation y_elementwise_op_;
         int blkGroupSize_;
         int numBlockTileIteration_;
@@ -301,16 +301,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                       GammaDataType,
                       BetaDataType,
                       YDataType,
-                      AccDataType,
-                      AccElementwiseOperation,
+                      ComputeDataType,
+                      YElementwiseOperation,
                       GridDesc_M_K>
                 : kernel_normalization<GridwiseReduceLayernormGeneric,
                       XDataType,
                       GammaDataType,
                       BetaDataType,
                       YDataType,
-                      AccDataType,
-                      AccElementwiseOperation,
+                      ComputeDataType,
+                      YElementwiseOperation,
                       GridDesc_M_K>;
             float avg_time = 0;
@@ -329,7 +329,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 arg.p_gamma_,
                 arg.p_beta_,
                 arg.p_y_,
-                arg.acc_elementwise_op_);
+                arg.y_elementwise_op_);
             return (avg_time);
         };
@@ -429,7 +429,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         void* p_y,
         void* p_saveMean,
         void* p_saveInvVar,
-        AccElementwiseOperation acc_elementwise_op) override
+        YElementwiseOperation y_elementwise_op) override
     {
         // TODO
         // Optional cache of the intermediate results (mean and InvVariance) during the
@@ -443,7 +443,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             betaStrides,
             yStrides,
             reduceDims,
-            acc_elementwise_op,
+            y_elementwise_op,
             epsilon,
             static_cast<const XDataType*>(p_x),
             static_cast<const GammaDataType*>(p_gamma),
...
@@ -16,8 +16,8 @@ template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename YDataType,
-          typename AccDataType,
-          typename AccElementwiseOperation,
+          typename ComputeDataType,
+          typename YElementwiseOperation,
           typename GridDesc_M_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
@@ -70,9 +70,9 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
     using ThreadwiseWelford =
-        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
-    using BlockwiseWelford = BlockwiseWelford<AccDataType,
+        ThreadwiseWelford<ComputeDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
+    using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
                                               BlockSize,
                                               ThreadClusterLengths_M_K,
                                               ThreadClusterArrangeOrder>;
@@ -115,12 +115,12 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                                const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
-                               AccDataType epsilon,
+                               ComputeDataType epsilon,
                                const XDataType* const __restrict__ p_x_global,
                                const GammaDataType* const __restrict__ p_gamma_global,
                                const BetaDataType* const __restrict__ p_beta_global,
                                YDataType* const __restrict__ p_y_global,
-                               const AccElementwiseOperation acc_elementwise_op)
+                               const YElementwiseOperation y_elementwise_op)
     {
         if constexpr(SweepOnce)
         {
@@ -133,7 +133,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         auto x_thread_buf = generate_tuple(
             [&](auto) {
                 return StaticBuffer<AddressSpaceEnum::Vgpr,
-                                    AccDataType,
+                                    ComputeDataType,
                                     MThreadSliceSize * XSrcVectorSize,
                                     true>{};
             },
@@ -142,7 +142,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         auto gamma_thread_buf = generate_tuple(
             [&](auto) {
                 return StaticBuffer<AddressSpaceEnum::Vgpr,
-                                    AccDataType,
+                                    ComputeDataType,
                                     MThreadSliceSize * GammaSrcVectorSize,
                                     true>{};
             },
@@ -151,7 +151,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         auto beta_thread_buf = generate_tuple(
             [&](auto) {
                 return StaticBuffer<AddressSpaceEnum::Vgpr,
-                                    AccDataType,
+                                    ComputeDataType,
                                     MThreadSliceSize * BetaSrcVectorSize,
                                     true>{};
             },
@@ -160,14 +160,16 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         auto y_thread_buf = generate_tuple(
             [&](auto) {
                 return StaticBuffer<AddressSpaceEnum::Vgpr,
-                                    AccDataType,
+                                    ComputeDataType,
                                     MThreadSliceSize * YDstVectorSize,
                                     true>{};
             },
             Number<ThreadBufferNumber>{});
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            var_thread_buf;
         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
@@ -179,7 +181,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         const auto thread_k_cluster_id = thread_cluster_idx[I1];
         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
-                                                                  AccDataType,
+                                                                  ComputeDataType,
                                                                   GridDesc_M_K,
                                                                   decltype(thread_buffer_desc_m_k),
                                                                   ThreadBufferLengths_M_K,
@@ -195,7 +197,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         auto threadwise_gamma_load =
             ThreadwiseTensorSliceTransfer_v2<GammaDataType,
-                                             AccDataType,
+                                             ComputeDataType,
                                              GridDesc_M_K,
                                              decltype(thread_buffer_desc_m_k),
                                              ThreadBufferLengths_M_K,
@@ -211,7 +213,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         auto threadwise_beta_load =
             ThreadwiseTensorSliceTransfer_v2<BetaDataType,
-                                             AccDataType,
+                                             ComputeDataType,
                                              GridDesc_M_K,
                                              decltype(thread_buffer_desc_m_k),
                                              ThreadBufferLengths_M_K,
@@ -226,11 +228,11 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                                              thread_k_cluster_id * BetaSrcVectorSize));
         auto threadwise_y_store =
-            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
+            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                                YDataType,
                                                decltype(thread_buffer_desc_m_k),
                                                GridDesc_M_K,
-                                               AccElementwiseOperation,
+                                               YElementwiseOperation,
                                                ThreadBufferLengths_M_K,
                                                ThreadBufferDimAccessOrder,
                                                YDstVectorDim,
@@ -242,7 +244,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                 make_multi_index(block_global_id * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  thread_k_cluster_id * YDstVectorSize),
-                acc_elementwise_op);
+                y_elementwise_op);
         constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
         constexpr auto thread_copy_bwd_step_m_k =
@@ -261,8 +263,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
-            var_thread_buf(I) = type_convert<AccDataType>(0.0f);
+            mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
+            var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
         });
         // Separate sweep once and sweep twice pipeline
...
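For context (an aside, not part of the diff): GridwiseNormalizationWelfordVariance_mk_to_mk accumulates mean and variance in ComputeDataType via Welford's algorithm (ThreadwiseWelford per thread slice, merged by BlockwiseWelford), which stays numerically stable even when fp16 inputs are reduced over long rows. A scalar sketch of the update it performs:

```cpp
#include <cstddef>

// Scalar Welford update: single pass, avoids the cancellation of the naive
// E[x^2] - E[x]^2 formula. The kernel runs this per element in the compute
// precision, then merges the per-thread partial results blockwise.
struct Welford
{
    float mean = 0.0f;
    float m2   = 0.0f; // running sum of squared deviations from the mean
    std::size_t count = 0;

    void push(float x)
    {
        ++count;
        const float delta = x - mean;
        mean += delta / static_cast<float>(count);
        m2 += delta * (x - mean); // note: uses the already-updated mean
    }

    // Population variance, matching layernorm's biased variance.
    float variance() const { return count ? m2 / static_cast<float>(count) : 0.0f; }
};
```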
@@ -21,7 +21,7 @@ template <typename OutElementwise, index_t Rank, index_t Reduce>
 // clang-format off
 using device_normalization_f16_instances =
     std::tuple <
-        // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
+        // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
         DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
         DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
         DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
...
@@ -19,7 +19,7 @@ using Pass = ck::tensor_operation::element_wise::PassThrough;
 template <typename OutElementwise, index_t Rank, index_t Reduce>
 using device_layernorm_f32_instances = std::tuple<
     // clang-format off
-    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
+    // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
     DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
     DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
     DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
...
@@ -19,7 +19,7 @@ namespace profiler {
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
-          typename AccDataType,
+          typename ComputeDataType,
           typename YDataType,
           index_t Rank>
 bool profile_layernorm_impl(int do_verification,
@@ -86,7 +86,7 @@ bool profile_layernorm_impl(int do_verification,
     using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
         GammaDataType,
         BetaDataType,
-        AccDataType,
+        ComputeDataType,
         YDataType,
         PassThrough,
         Rank,
@@ -109,7 +109,7 @@ bool profile_layernorm_impl(int do_verification,
         GammaDataType,
         BetaDataType,
         YDataType,
-        AccDataType,
+        ComputeDataType,
         PassThrough,
         Rank,
         NumReduceDim>;
...
@@ -15,7 +15,7 @@ class TestGroupnorm : public ::testing::Test
     using XDataType = std::tuple_element_t<0, Tuple>;
     using GammaDataType = std::tuple_element_t<1, Tuple>;
     using BetaDataType = std::tuple_element_t<2, Tuple>;
-    using AccDataType = std::tuple_element_t<3, Tuple>;
+    using ComputeDataType = std::tuple_element_t<3, Tuple>;
     using YDataType = std::tuple_element_t<4, Tuple>;
     void Run()
@@ -36,7 +36,7 @@ class TestGroupnorm : public ::testing::Test
             ck::profiler::profile_groupnorm_impl<XDataType,
                                                  GammaDataType,
                                                  BetaDataType,
-                                                 AccDataType,
+                                                 ComputeDataType,
                                                  YDataType>(true, 2, false, false, length);
         EXPECT_TRUE(success);
     }
@@ -44,7 +44,7 @@ class TestGroupnorm : public ::testing::Test
 };
 using KernelTypes = ::testing::Types<
-    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>
+    // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
     std::tuple<F16, F16, F16, F32, F16>>;
 TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
...
@@ -15,7 +15,7 @@ class TestGroupnorm : public ::testing::Test
     using XDataType = std::tuple_element_t<0, Tuple>;
     using GammaDataType = std::tuple_element_t<1, Tuple>;
     using BetaDataType = std::tuple_element_t<2, Tuple>;
-    using AccDataType = std::tuple_element_t<3, Tuple>;
+    using ComputeDataType = std::tuple_element_t<3, Tuple>;
     using YDataType = std::tuple_element_t<4, Tuple>;
     void Run()
@@ -34,7 +34,7 @@ class TestGroupnorm : public ::testing::Test
             ck::profiler::profile_groupnorm_impl<XDataType,
                                                  GammaDataType,
                                                  BetaDataType,
-                                                 AccDataType,
+                                                 ComputeDataType,
                                                  YDataType>(true, 2, false, false, length);
         EXPECT_TRUE(success);
     }
@@ -42,7 +42,7 @@ class TestGroupnorm : public ::testing::Test
 };
 using KernelTypes = ::testing::Types<
-    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>
+    // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
     std::tuple<F32, F32, F32, F32, F32>>;
 TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
...
@@ -15,7 +15,7 @@ class TestLayernorm2d : public ::testing::Test
     using XDataType = std::tuple_element_t<0, Tuple>;
     using GammaDataType = std::tuple_element_t<1, Tuple>;
     using BetaDataType = std::tuple_element_t<2, Tuple>;
-    using AccDataType = std::tuple_element_t<3, Tuple>;
+    using ComputeDataType = std::tuple_element_t<3, Tuple>;
    using YDataType = std::tuple_element_t<4, Tuple>;
     void Run()
@@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test
         bool success = ck::profiler::profile_layernorm_impl<XDataType,
                                                             GammaDataType,
                                                             BetaDataType,
-                                                            AccDataType,
+                                                            ComputeDataType,
                                                             YDataType,
                                                             2>(true, 2, false, false, length);
         EXPECT_TRUE(success);
@@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test
 };
 using KernelTypes = ::testing::Types<
-    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>
+    // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
     std::tuple<F16, F16, F16, F32, F16>>;
 TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes);
...
@@ -15,7 +15,7 @@ class TestLayernorm2d : public ::testing::Test
     using XDataType = std::tuple_element_t<0, Tuple>;
     using GammaDataType = std::tuple_element_t<1, Tuple>;
     using BetaDataType = std::tuple_element_t<2, Tuple>;
-    using AccDataType = std::tuple_element_t<3, Tuple>;
+    using ComputeDataType = std::tuple_element_t<3, Tuple>;
     using YDataType = std::tuple_element_t<4, Tuple>;
     void Run()
@@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test
         bool success = ck::profiler::profile_layernorm_impl<XDataType,
                                                             GammaDataType,
                                                             BetaDataType,
-                                                            AccDataType,
+                                                            ComputeDataType,
                                                             YDataType,
                                                             2>(true, 2, false, false, length);
         EXPECT_TRUE(success);
@@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test
 };
 using KernelTypes = ::testing::Types<
-    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>
+    // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
     std::tuple<F32, F32, F32, F32, F32>>;
 TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes);
...