Unverified Commit 6a6163a3 authored by rocking5566's avatar rocking5566 Committed by GitHub
Browse files

Improve normalization (#580)

* Sync the order of type string with template parameter

* Add more instances

* Check the vector size and remove redundant var

* Extract var to static, prepare to separate sweep once kernel

* Separate sweeponce flow and optimize the flow

* 1. Rename AccDatatype in normalization to computeData
2. Rename AccElementwiseOperation to YElementwiseOperation in normalization

* Remove useless code

* Update naive variance kernel

* Refine string

* Fix typo

* Support naive variance for device_normalization

* Check the blocksize

* Share the VGPR of x and y

* Share the VGPR of gamma and beta

* Add more instances

* Support fp16 sqrt for experiment

* Add CHANGELOG

* Fix typo

* clang-format
parent 0cfda84d
...@@ -9,7 +9,7 @@ Full documentation for Composable Kernel is not yet available. ...@@ -9,7 +9,7 @@ Full documentation for Composable Kernel is not yet available.
- Fixed grouped ConvBwdWeight test case failure (#524). - Fixed grouped ConvBwdWeight test case failure (#524).
### Optimizations ### Optimizations
- Optimized ... - Improve proformance of normalization kernel
### Added ### Added
- Added user tutorial (#563). - Added user tutorial (#563).
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp" #include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using AccDataType = float; using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
...@@ -54,7 +54,7 @@ int main(int argc, char* argv[]) ...@@ -54,7 +54,7 @@ int main(int argc, char* argv[])
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType, YDataType,
PassThrough, PassThrough,
Rank, Rank,
......
...@@ -20,12 +20,12 @@ ...@@ -20,12 +20,12 @@
#include "ck/library/utility/literals.hpp" #include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using AccDataType = float; using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
...@@ -34,7 +34,7 @@ using DeviceInstance = ...@@ -34,7 +34,7 @@ using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType, ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType, YDataType,
PassThrough, PassThrough,
Rank, Rank,
...@@ -121,7 +121,7 @@ int main() ...@@ -121,7 +121,7 @@ int main()
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, ComputeDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
......
...@@ -23,11 +23,11 @@ ...@@ -23,11 +23,11 @@
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using AccDataType = float; using ComputeDataType = float;
struct YElementOp struct YElementOp
{ {
...@@ -50,7 +50,7 @@ using DeviceInstance = ...@@ -50,7 +50,7 @@ using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType, ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType, YDataType,
YElementOp, YElementOp,
Rank, Rank,
...@@ -157,7 +157,7 @@ int main(int argc, char* argv[]) ...@@ -157,7 +157,7 @@ int main(int argc, char* argv[])
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, ComputeDataType,
YElementOp>; YElementOp>;
ReferenceInstance ref; ReferenceInstance ref;
......
...@@ -14,9 +14,9 @@ namespace device { ...@@ -14,9 +14,9 @@ namespace device {
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename AccDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename AccElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
struct DeviceNormalization : public BaseOperator struct DeviceNormalization : public BaseOperator
...@@ -35,7 +35,7 @@ struct DeviceNormalization : public BaseOperator ...@@ -35,7 +35,7 @@ struct DeviceNormalization : public BaseOperator
void* p_y, void* p_y,
void* p_savedMean, void* p_savedMean,
void* p_savedInvVar, void* p_savedInvVar,
AccElementwiseOperation acc_elementwise_op) = 0; YElementwiseOperation y_elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
...@@ -43,17 +43,17 @@ struct DeviceNormalization : public BaseOperator ...@@ -43,17 +43,17 @@ struct DeviceNormalization : public BaseOperator
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename AccDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename AccElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType, using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType, YDataType,
AccElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim>>; NumReduceDim>>;
......
...@@ -10,46 +10,11 @@ ...@@ -10,46 +10,11 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseReduction,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K>
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
acc_elementwise_op);
};
} // namespace ck
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -58,9 +23,9 @@ namespace device { ...@@ -58,9 +23,9 @@ namespace device {
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename AccDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename AccElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim, index_t NumReduceDim,
index_t BlockSize, index_t BlockSize,
...@@ -74,16 +39,18 @@ template <typename XDataType, ...@@ -74,16 +39,18 @@ template <typename XDataType,
index_t GammaSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim, index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize> index_t YDstVectorSize,
bool UseWelford = true>
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType, YDataType,
AccElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim> NumReduceDim>
{ {
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert( static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
(GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
...@@ -167,51 +134,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -167,51 +134,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseReduceLayernormGeneric =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
false>;
using GridwiseNormalizationSweepOnce =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
true>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> lengths, Argument(const std::vector<index_t> lengths,
...@@ -220,7 +142,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -220,7 +142,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccElementwiseOperation acc_elementwise_op, YElementwiseOperation y_elementwise_op,
double epsilon, double epsilon,
const XDataType* p_x, const XDataType* p_x,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
...@@ -230,9 +152,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -230,9 +152,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y), p_y_(p_y),
acc_elementwise_op_(acc_elementwise_op) y_elementwise_op_(y_elementwise_op)
{ {
epsilon_ = static_cast<AccDataType>(epsilon); epsilon_ = static_cast<ComputeDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims); Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims); xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
...@@ -265,7 +187,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -265,7 +187,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
} }
AccDataType epsilon_; ComputeDataType epsilon_;
const XDataType* p_x_; const XDataType* p_x_;
const GammaDataType* p_gamma_; const GammaDataType* p_gamma_;
...@@ -278,7 +200,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -278,7 +200,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
std::vector<index_t> betaStrides_; std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_; std::vector<index_t> yStrides_;
AccElementwiseOperation acc_elementwise_op_; YElementwiseOperation y_elementwise_op_;
int blkGroupSize_; int blkGroupSize_;
int numBlockTileIteration_; int numBlockTileIteration_;
...@@ -295,23 +217,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -295,23 +217,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto kernel_main = arg.isSweeponce_ auto kernel_main = NormalizationKernelSelector<XDataType,
? kernel_normalization<GridwiseNormalizationSweepOnce, GammaDataType,
XDataType, BetaDataType,
GammaDataType, YDataType,
BetaDataType, ComputeDataType,
YDataType, YElementwiseOperation,
AccDataType, GridDesc_M_K,
AccElementwiseOperation, BlockSize,
GridDesc_M_K> MThreadClusterSize,
: kernel_normalization<GridwiseReduceLayernormGeneric, KThreadClusterSize,
XDataType, MThreadSliceSize,
GammaDataType, KThreadSliceSize,
BetaDataType, XYSrcVectorDim,
YDataType, XSrcVectorSize,
AccDataType, GammaSrcVectorDim,
AccElementwiseOperation, GammaSrcVectorSize,
GridDesc_M_K>; BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
UseWelford>(arg.isSweeponce_);
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
...@@ -329,7 +255,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -329,7 +255,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
arg.p_gamma_, arg.p_gamma_,
arg.p_beta_, arg.p_beta_,
arg.p_y_, arg.p_y_,
arg.acc_elementwise_op_); arg.y_elementwise_op_);
return (avg_time); return (avg_time);
}; };
...@@ -429,7 +355,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -429,7 +355,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
void* p_y, void* p_y,
void* p_saveMean, void* p_saveMean,
void* p_saveInvVar, void* p_saveInvVar,
AccElementwiseOperation acc_elementwise_op) override YElementwiseOperation y_elementwise_op) override
{ {
// TODO // TODO
// Optional cache of the intermediate results (mean and InvVariance) during the // Optional cache of the intermediate results (mean and InvVariance) during the
...@@ -443,7 +369,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -443,7 +369,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
betaStrides, betaStrides,
yStrides, yStrides,
reduceDims, reduceDims,
acc_elementwise_op, y_elementwise_op,
epsilon, epsilon,
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma), static_cast<const GammaDataType*>(p_gamma),
...@@ -462,8 +388,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -462,8 +388,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
// clang-format off // clang-format off
str << "DeviceNormalizationImpl<" << BlockSize << ","; str << "DeviceNormalizationImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
// clang-format on // clang-format on
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
namespace ck {
template <typename GridwiseReduction,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K>
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const YElementwiseOperation y_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
y_elementwise_op);
};
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
bool UseWelford>
auto NormalizationKernelSelector(bool isSweepOnce)
{
using GridwiseNormalizationGenericNaive =
GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
false>;
using GridwiseNormalizationSweepOnceNaive =
GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
true>;
using GridwiseNormalizationGenericWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
false>;
using GridwiseNormalizationSweepOnceWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
true>;
if constexpr(UseWelford)
{
return isSweepOnce ? kernel_normalization<GridwiseNormalizationSweepOnceWelford,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>
: kernel_normalization<GridwiseNormalizationGenericWelford,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>;
}
else
{
return isSweepOnce ? kernel_normalization<GridwiseNormalizationSweepOnceNaive,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>
: kernel_normalization<GridwiseNormalizationGenericNaive,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>;
}
}
} // namespace ck
...@@ -83,6 +83,11 @@ static inline __host__ bool isnan(int4_t x) ...@@ -83,6 +83,11 @@ static inline __host__ bool isnan(int4_t x)
}; };
#endif #endif
static inline __host__ half_t sqrt(half_t x)
{
return static_cast<half_t>(std::sqrt(static_cast<float>(x)));
};
static inline __host__ float sqrt(float x) { return std::sqrt(x); }; static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); }; static inline __host__ double sqrt(double x) { return std::sqrt(x); };
...@@ -158,6 +163,11 @@ static inline __device__ bool isnan(half_t x) ...@@ -158,6 +163,11 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __device__ half_t sqrt(half_t x)
{
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }; static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
......
...@@ -21,20 +21,25 @@ template <typename OutElementwise, index_t Rank, index_t Reduce> ...@@ -21,20 +21,25 @@ template <typename OutElementwise, index_t Rank, index_t Reduce>
// clang-format off // clang-format off
using device_normalization_f16_instances = using device_normalization_f16_instances =
std::tuple < std::tuple <
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>, // fallback kernel DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 1, 8, 1, 8, 8>, DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>, DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>, DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>, DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>, DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>, DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>, DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>, DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 32, 1, 8, 1, 8, 1, 8, 8>, DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 2, 1, 2, 1, 2, 2> DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8>
>; >;
// clang-format on // clang-format on
......
...@@ -19,17 +19,26 @@ using Pass = ck::tensor_operation::element_wise::PassThrough; ...@@ -19,17 +19,26 @@ using Pass = ck::tensor_operation::element_wise::PassThrough;
template <typename OutElementwise, index_t Rank, index_t Reduce> template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_layernorm_f32_instances = std::tuple< using device_layernorm_f32_instances = std::tuple<
// clang-format off // clang-format off
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>, DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 1, 4, 1, 4, 4>, DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>, DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>, DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>, DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>, DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>, DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4> DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on // clang-format on
>; >;
......
...@@ -19,7 +19,7 @@ namespace profiler { ...@@ -19,7 +19,7 @@ namespace profiler {
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename AccDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
index_t Rank> index_t Rank>
bool profile_layernorm_impl(int do_verification, bool profile_layernorm_impl(int do_verification,
...@@ -86,7 +86,7 @@ bool profile_layernorm_impl(int do_verification, ...@@ -86,7 +86,7 @@ bool profile_layernorm_impl(int do_verification,
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType, YDataType,
PassThrough, PassThrough,
Rank, Rank,
...@@ -109,7 +109,7 @@ bool profile_layernorm_impl(int do_verification, ...@@ -109,7 +109,7 @@ bool profile_layernorm_impl(int do_verification,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, ComputeDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
...@@ -181,8 +181,8 @@ bool profile_layernorm_impl(int do_verification, ...@@ -181,8 +181,8 @@ bool profile_layernorm_impl(int do_verification,
{ {
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
bool pass = ck::utils::check_err( bool pass =
y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3); ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
if(do_log) if(do_log)
{ {
......
...@@ -12,11 +12,11 @@ template <typename Tuple> ...@@ -12,11 +12,11 @@ template <typename Tuple>
class TestGroupnorm : public ::testing::Test class TestGroupnorm : public ::testing::Test
{ {
protected: protected:
using XDataType = std::tuple_element_t<0, Tuple>; using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>; using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>; using BetaDataType = std::tuple_element_t<2, Tuple>;
using AccDataType = std::tuple_element_t<3, Tuple>; using ComputeDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>; using YDataType = std::tuple_element_t<4, Tuple>;
void Run() void Run()
{ {
...@@ -36,7 +36,7 @@ class TestGroupnorm : public ::testing::Test ...@@ -36,7 +36,7 @@ class TestGroupnorm : public ::testing::Test
ck::profiler::profile_groupnorm_impl<XDataType, ck::profiler::profile_groupnorm_impl<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType>(true, 2, false, false, length); YDataType>(true, 2, false, false, length);
EXPECT_TRUE(success); EXPECT_TRUE(success);
} }
...@@ -44,7 +44,7 @@ class TestGroupnorm : public ::testing::Test ...@@ -44,7 +44,7 @@ class TestGroupnorm : public ::testing::Test
}; };
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F16, F16, F16, F32, F16>>; std::tuple<F16, F16, F16, F32, F16>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
......
...@@ -12,11 +12,11 @@ template <typename Tuple> ...@@ -12,11 +12,11 @@ template <typename Tuple>
class TestGroupnorm : public ::testing::Test class TestGroupnorm : public ::testing::Test
{ {
protected: protected:
using XDataType = std::tuple_element_t<0, Tuple>; using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>; using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>; using BetaDataType = std::tuple_element_t<2, Tuple>;
using AccDataType = std::tuple_element_t<3, Tuple>; using ComputeDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>; using YDataType = std::tuple_element_t<4, Tuple>;
void Run() void Run()
{ {
...@@ -34,7 +34,7 @@ class TestGroupnorm : public ::testing::Test ...@@ -34,7 +34,7 @@ class TestGroupnorm : public ::testing::Test
ck::profiler::profile_groupnorm_impl<XDataType, ck::profiler::profile_groupnorm_impl<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType>(true, 2, false, false, length); YDataType>(true, 2, false, false, length);
EXPECT_TRUE(success); EXPECT_TRUE(success);
} }
...@@ -42,7 +42,7 @@ class TestGroupnorm : public ::testing::Test ...@@ -42,7 +42,7 @@ class TestGroupnorm : public ::testing::Test
}; };
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F32, F32, F32, F32, F32>>; std::tuple<F32, F32, F32, F32, F32>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
......
...@@ -12,11 +12,11 @@ template <typename Tuple> ...@@ -12,11 +12,11 @@ template <typename Tuple>
class TestLayernorm2d : public ::testing::Test class TestLayernorm2d : public ::testing::Test
{ {
protected: protected:
using XDataType = std::tuple_element_t<0, Tuple>; using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>; using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>; using BetaDataType = std::tuple_element_t<2, Tuple>;
using AccDataType = std::tuple_element_t<3, Tuple>; using ComputeDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>; using YDataType = std::tuple_element_t<4, Tuple>;
void Run() void Run()
{ {
...@@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test ...@@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test
bool success = ck::profiler::profile_layernorm_impl<XDataType, bool success = ck::profiler::profile_layernorm_impl<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType, YDataType,
2>(true, 2, false, false, length); 2>(true, 2, false, false, length);
EXPECT_TRUE(success); EXPECT_TRUE(success);
...@@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test ...@@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test
}; };
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F16, F16, F16, F32, F16>>; std::tuple<F16, F16, F16, F32, F16>>;
TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes);
......
...@@ -12,11 +12,11 @@ template <typename Tuple> ...@@ -12,11 +12,11 @@ template <typename Tuple>
class TestLayernorm2d : public ::testing::Test class TestLayernorm2d : public ::testing::Test
{ {
protected: protected:
using XDataType = std::tuple_element_t<0, Tuple>; using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>; using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>; using BetaDataType = std::tuple_element_t<2, Tuple>;
using AccDataType = std::tuple_element_t<3, Tuple>; using ComputeDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>; using YDataType = std::tuple_element_t<4, Tuple>;
void Run() void Run()
{ {
...@@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test ...@@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test
bool success = ck::profiler::profile_layernorm_impl<XDataType, bool success = ck::profiler::profile_layernorm_impl<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType, YDataType,
2>(true, 2, false, false, length); 2>(true, 2, false, false, length);
EXPECT_TRUE(success); EXPECT_TRUE(success);
...@@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test ...@@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test
}; };
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F32, F32, F32, F32, F32>>; std::tuple<F32, F32, F32, F32, F32>>;
TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment