Commit 479ec5d9 authored by Astha Rai's avatar Astha Rai
Browse files

Merge branch 'gridwise_2d' of github.com:ROCmSoftwarePlatform/composable_kernel into gridwise_2d

parents 73536f4f 2c4305b2
......@@ -48,7 +48,7 @@ class BatchnormFwdArgParser
std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc" << std::endl;
std::cout << "--reduceDims or -R, comma separated list of dimensions to reduce on" << std::endl;
std::cout << "--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization" << std::endl;
std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl;
std::cout << "Arg1: data type (0: fp16, 1: fp32, 5: bp16, 6: fp64)" << std::endl;
std::cout << "Arg2: 1/0 to indicate whether to update the moving average and variance (0=no, 1=yes)" << std::endl;
std::cout << "Arg3: 1/0 to indicate whether to save the calculated mean and invVariance (0=no, 1=yes)" << std::endl;
std::cout << "Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)" << std::endl;
......@@ -141,7 +141,6 @@ int profile_batchnorm_forward(int argc, char* argv[])
using F16 = ck::half_t;
using F32 = float;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F64 = double;
if(arg_parser.data_type == 0)
......@@ -178,23 +177,6 @@ int profile_batchnorm_forward(int argc, char* argv[])
averageFactor);
};
}
else if(arg_parser.data_type == 3)
{
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
{
profile_batchnorm_forward_impl<I8, I8, F32, I8, I8, F32, 4, 3>(
arg_parser.do_verification,
arg_parser.init_method,
arg_parser.do_dumpout,
arg_parser.time_kernel,
arg_parser.inLengths,
arg_parser.reduceDims,
arg_parser.updateMovingAverage,
arg_parser.saveMeanAndInvVariance,
epsilon,
averageFactor);
};
}
else if(arg_parser.data_type == 5)
{
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
......
......@@ -90,7 +90,6 @@ class TestBatchNormFwdRank4 : public ::testing::Test
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, F32>,
std::tuple<F32, F32, F32, F32, F32, F32>,
std::tuple<BF16, BF16, F32, BF16, BF16, F32>,
std::tuple<I8, I8, F32, I8, I8, F32>,
std::tuple<F64, F64, F64, F64, F64, F64>>;
TYPED_TEST_SUITE(TestBatchNormFwdRank4, KernelTypes);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment