Unverified commit 5bf0475a authored by Qianfeng, committed by GitHub

Remove int8 from batchnorm-forward instances since it is not needed for forward training and could fail tests (#516)
parent 4e6a5575
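
For context, here is a minimal usage sketch (not part of the diff) of how the factory specialization touched below is normally consumed. The header path and the GetInstances() call follow the usual CK instance-factory pattern and are assumptions, not code from this commit; after this change only the fp16, fp32, bf16 and fp64 data-type combinations yield instances, while an int8 combination simply falls through every `if constexpr` branch and returns an empty list.

// Illustrative sketch only. The include path and GetInstances() are assumptions
// based on the usual CK instance-factory pattern, not code from this commit.
#include <memory>
#include <vector>

#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp" // assumed path

namespace {

using F16         = ck::half_t;
using F32         = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

void list_fp16_batchnorm_forward_instances()
{
    // Rank-4 input reduced over 3 dimensions, matching the instances in the diff below.
    using DeviceOp = ck::tensor_operation::device::
        DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>;

    // One unique_ptr per registered tuning configuration; after this commit an
    // int8 DeviceOp matches no factory branch and would yield an empty vector.
    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    (void)op_ptrs;
}

} // namespace
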
@@ -31,10 +31,6 @@ void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
// Int8
void add_device_batchnorm_forward_rank_4_3_i8_instances(
std::vector<std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&);
// FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
@@ -101,15 +97,6 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, I8> && is_same_v<YDataType, I8> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, I8> &&
is_same_v<BiasDataType, I8> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_i8_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
@@ -2,6 +2,5 @@ add_instance_library(device_batchnorm_instance
device_batchnorm_forward_f16_instance.cpp
device_batchnorm_forward_f32_instance.cpp
device_batchnorm_forward_bf16_instance.cpp
device_batchnorm_forward_i8_instance.cpp
device_batchnorm_forward_f64_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_i8_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_i8_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_i8_instances(
std::vector<std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_i8_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_i8_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
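
For reference, deleting the two instance tuples and the add_device_batchnorm_forward_rank_4_3_i8_instances registration above removes every int8 configuration from the built library. A simplified, illustrative version of what add_device_operation_instances does with such a tuple (not the exact CK implementation) looks roughly like this:

// Rough illustration only: default-construct each tuple element type and append
// it to the caller's instance vector, one entry per tuning configuration.
#include <memory>
#include <tuple>
#include <vector>

template <typename BaseOp, typename... Ops>
void add_device_operation_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& instances,
                                           const std::tuple<Ops...>&)
{
    (instances.push_back(std::make_unique<Ops>()), ...);
}
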
@@ -85,39 +85,22 @@ bool profile_batchnorm_forward_impl(int do_verification,
if(updateMovingAverage)
{
if constexpr(ck::is_same_v<XDataType, int8_t>)
{
x.GenerateTensorValue(GeneratorTensor_2<XDataType>{-5, 5}, num_thread);
const float x_mean = 0.0f;
const float x_stddev = 2.5f;
const float noise_stddev = 0.04f;
resultRunningMean_ref.GenerateTensorValue(
GeneratorTensor_4<MeanVarDataType>{x_mean, noise_stddev}, num_thread);
resultRunningVariance_ref.GenerateTensorValue(
GeneratorTensor_4<MeanVarDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
}
else
{
const float x_mean = 0.0f;
const float x_stddev = 1.0f;
const float noise_stddev = 0.04f;
// input data in normal distribution
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
// initialize the runningMean to be values with tiny variation to the mean of the x
// values
resultRunningMean_ref.GenerateTensorValue(
GeneratorTensor_4<MeanVarDataType>{x_mean, noise_stddev}, num_thread);
// initialize the runningVariance to be values with tiny variation to the variance of
// the x values
resultRunningVariance_ref.GenerateTensorValue(
GeneratorTensor_4<MeanVarDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
};
const float x_mean = 0.0f;
const float x_stddev = 1.0f;
const float noise_stddev = 0.04f;
// input data in normal distribution
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
// initialize the runningMean to be values with tiny variation to the mean of the x
// values
resultRunningMean_ref.GenerateTensorValue(
GeneratorTensor_4<MeanVarDataType>{x_mean, noise_stddev}, num_thread);
// initialize the runningVariance to be values with tiny variation to the variance of
// the x values
resultRunningVariance_ref.GenerateTensorValue(
GeneratorTensor_4<MeanVarDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
}
else
{
@@ -129,35 +112,24 @@ bool profile_batchnorm_forward_impl(int do_verification,
if(do_verification)
{
if constexpr(ck::is_same_v<ScaleDataType, int8_t> && ck::is_same_v<BiasDataType, int8_t>)
switch(init_method)
{
case 0:
bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_0<BiasDataType>{}, num_thread);
break;
case 1:
bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_1<BiasDataType>{0}, num_thread);
break;
case 2:
bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_2<BiasDataType>{-5, 5}, num_thread);
break;
default:
bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-1.0f, 1.0f}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1.0f, 1.0f}, num_thread);
}
else
{
switch(init_method)
{
case 0:
bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_0<BiasDataType>{}, num_thread);
break;
case 1:
bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_1<BiasDataType>{0}, num_thread);
break;
case 2:
bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_2<BiasDataType>{-5, 5}, num_thread);
break;
default:
bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-1.0f, 1.0f},
num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1.0f, 1.0f},
num_thread);
}
};
};
// these buffers are usually provided by the user application
@@ -48,7 +48,7 @@ class BatchnormFwdArgParser
std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc" << std::endl;
std::cout << "--reduceDims or -R, comma separated list of dimensions to reduce on" << std::endl;
std::cout << "--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization" << std::endl;
std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl;
std::cout << "Arg1: data type (0: fp16, 1: fp32, 5: bp16, 6: fp64)" << std::endl;
std::cout << "Arg2: 1/0 to indicate whether to update the moving average and variance (0=no, 1=yes)" << std::endl;
std::cout << "Arg3: 1/0 to indicate whether to save the calculated mean and invVariance (0=no, 1=yes)" << std::endl;
std::cout << "Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)" << std::endl;
@@ -141,7 +141,6 @@ int profile_batchnorm_forward(int argc, char* argv[])
using F16 = ck::half_t;
using F32 = float;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F64 = double;
if(arg_parser.data_type == 0)
@@ -178,23 +177,6 @@ int profile_batchnorm_forward(int argc, char* argv[])
averageFactor);
};
}
else if(arg_parser.data_type == 3)
{
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
{
profile_batchnorm_forward_impl<I8, I8, F32, I8, I8, F32, 4, 3>(
arg_parser.do_verification,
arg_parser.init_method,
arg_parser.do_dumpout,
arg_parser.time_kernel,
arg_parser.inLengths,
arg_parser.reduceDims,
arg_parser.updateMovingAverage,
arg_parser.saveMeanAndInvVariance,
epsilon,
averageFactor);
};
}
else if(arg_parser.data_type == 5)
{
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
@@ -90,7 +90,6 @@ class TestBatchNormFwdRank4 : public ::testing::Test
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, F32>,
std::tuple<F32, F32, F32, F32, F32, F32>,
std::tuple<BF16, BF16, F32, BF16, BF16, F32>,
std::tuple<I8, I8, F32, I8, I8, F32>,
std::tuple<F64, F64, F64, F64, F64, F64>>;
TYPED_TEST_SUITE(TestBatchNormFwdRank4, KernelTypes);