Commit 7df98b04 authored by rocking's avatar rocking
Browse files

Change save mean var type from f32 to f16 in f16 mode

parent de1cb1ed
......@@ -16,7 +16,7 @@ using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using SaveMeanInvStdDataType = float;
using SaveMeanInvStdDataType = ck::half_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
......
......@@ -16,7 +16,7 @@ using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using SaveMeanInvStdDataType = float;
using SaveMeanInvStdDataType = ck::half_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
......
......@@ -20,15 +20,15 @@ namespace instance {
// FP16
void add_device_normalization_fwd_rank_2_1_f16_instances(
std::vector<
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&);
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, PassThrough, 2, 1>>>&);
void add_device_normalization_fwd_rank_4_3_f16_instances(
std::vector<
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&);
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, PassThrough, 4, 3>>>&);
void add_device_normalization_fwd_rank_5_3_f16_instances(
std::vector<
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&);
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, PassThrough, 5, 3>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
......@@ -76,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<SaveMeanInvStdDataType, F32>)
is_same_v<SaveMeanInvStdDataType, F16>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
......
......@@ -19,7 +19,7 @@ namespace instance {
// FP16
void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&);
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Swish, 5, 3>>>&);
// FP32
void add_device_normalization_fwd_rank_5_3_swish_f32_instances(
......@@ -61,7 +61,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<SaveMeanInvStdDataType, F32>)
is_same_v<SaveMeanInvStdDataType, F16>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
......
......@@ -11,7 +11,7 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_fwd_rank_5_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 5, 3>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Pass, 5, 3>>>&
instances)
{
add_device_operation_instances(instances,
......
......@@ -11,7 +11,7 @@ namespace instance {
using Swish = ck::tensor_operation::element_wise::Swish;
void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Swish, 5, 3>>>&
instances)
{
add_device_operation_instances(instances,
......
......@@ -11,7 +11,7 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_fwd_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 2, 1>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Pass, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
......
......@@ -11,7 +11,7 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_fwd_rank_4_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 4, 3>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Pass, 4, 3>>>&
instances)
{
add_device_operation_instances(instances,
......
......@@ -23,24 +23,24 @@ using device_normalization_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
// clang-format on
>;
......@@ -49,31 +49,31 @@ using device_normalization_splitk_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
// clang-format on
>;
......
......@@ -98,7 +98,7 @@ int profile_groupnorm(int argc, char* argv[])
}
else if(data_type == ck::DataTypeEnum::Half)
{
ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16, F32, false>(
ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16, F16, false>(
do_verification, init_method, do_log, time_kernel, length);
}
else
......
......@@ -104,7 +104,7 @@ int profile_layernorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Half)
{
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F32, false, rank>(
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F16, false, rank>(
do_verification, init_method, do_log, time_kernel, length);
}
else if(data_type == ck::DataTypeEnum::Float)
......
......@@ -47,8 +47,8 @@ class TestGroupnorm : public ::testing::Test
};
using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F16, F16, F16, F32, F16, F32>>;
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType>
std::tuple<F16, F16, F16, F32, F16, F16>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
TYPED_TEST(TestGroupnorm, Test_FP16) { this->Run(); }
......@@ -45,7 +45,7 @@ class TestGroupnorm : public ::testing::Test
};
using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType>
std::tuple<F32, F32, F32, F32, F32, F32>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
......
......@@ -41,8 +41,8 @@ class TestLayernorm2d : public ::testing::Test
};
using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F16, F16, F16, F32, F16, F32>>;
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType>
std::tuple<F16, F16, F16, F32, F16, F16>>;
TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes);
TYPED_TEST(TestLayernorm2d, Test_FP16) { this->Run(); }
......@@ -41,8 +41,8 @@ class TestLayernorm4d : public ::testing::Test
};
using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F16, F16, F16, F32, F16, F32>>;
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType>
std::tuple<F16, F16, F16, F32, F16, F16>>;
TYPED_TEST_SUITE(TestLayernorm4d, KernelTypes);
TYPED_TEST(TestLayernorm4d, Test_FP16) { this->Run(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment