Unverified Commit 8f8a2ce3 authored by jakpiase's avatar jakpiase Committed by GitHub
Browse files

Add pool2d int8 and fp8 instances (#1508)

* add pool2d fp8 and int8

* minor fixes

* add formatting

* add reviewer suggestions

* add reviewer suggestions
parent a4982c3b
...@@ -67,6 +67,36 @@ void add_device_pool2d_fwd_nhwc_index_f32_instances( ...@@ -67,6 +67,36 @@ void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, true>>>&); DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif #endif
#ifdef CK_ENABLE_INT8
// I8
void add_device_pool2d_fwd_nhwc_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, AvgOp, false>>>&);
// I8 - return index
void add_device_pool2d_fwd_nhwc_index_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP8
// F8
void add_device_pool2d_fwd_nhwc_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, AvgOp, false>>>&);
// F8 - return index
void add_device_pool2d_fwd_nhwc_index_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename IndexDataType, typename IndexDataType,
...@@ -140,6 +170,34 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw ...@@ -140,6 +170,34 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs); add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
} }
} }
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, I8> && is_same_v<OutDataType, I8> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_i8_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_i8_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP8
else if constexpr(is_same_v<InDataType, F8> && is_same_v<OutDataType, F8> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f8_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f8_instances(op_ptrs);
}
}
#endif #endif
} }
......
...@@ -4,5 +4,9 @@ list(APPEND DEVICE_POOL2D_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f16_instance. ...@@ -4,5 +4,9 @@ list(APPEND DEVICE_POOL2D_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f16_instance.
device_avg_pool2d_fwd_nhwc_f32_instance.cpp device_avg_pool2d_fwd_nhwc_f32_instance.cpp
device_max_pool2d_fwd_nhwc_f32_instance.cpp device_max_pool2d_fwd_nhwc_f32_instance.cpp
device_avg_pool2d_fwd_nhwc_bf16_instance.cpp device_avg_pool2d_fwd_nhwc_bf16_instance.cpp
device_max_pool2d_fwd_nhwc_bf16_instance.cpp) device_max_pool2d_fwd_nhwc_bf16_instance.cpp
device_avg_pool2d_fwd_nhwc_i8_instance.cpp
device_max_pool2d_fwd_nhwc_i8_instance.cpp
device_avg_pool2d_fwd_nhwc_f8_instance.cpp
device_max_pool2d_fwd_nhwc_f8_instance.cpp)
add_instance_library(device_pool2d_fwd_instance ${DEVICE_POOL2D_FWD_INSTANCES}) add_instance_library(device_pool2d_fwd_instance ${DEVICE_POOL2D_FWD_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pool2d_fwd_nhwc_f8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pool2d_fwd_nhwc_i8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, false>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pool2d_fwd_nhwc_f8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
}
void add_device_pool2d_fwd_nhwc_index_f8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, true>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, true>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pool2d_fwd_nhwc_i8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, false>{});
}
void add_device_pool2d_fwd_nhwc_index_i8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, true>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, true>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -15,9 +15,11 @@ namespace device { ...@@ -15,9 +15,11 @@ namespace device {
namespace instance { namespace instance {
using I32 = int32_t; using I32 = int32_t;
using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F32 = float; using I8 = int8_t;
using F8 = ck::f8_t;
using NHWC = ck::tensor_layout::convolution::NHWC; using NHWC = ck::tensor_layout::convolution::NHWC;
template <typename InDataType, template <typename InDataType,
......
...@@ -49,9 +49,18 @@ struct maxPoolFwdArgParser ...@@ -49,9 +49,18 @@ struct maxPoolFwdArgParser
} }
}; };
enum struct PoolDataType
{
F32 = 0,
BF16,
F16,
INT8,
F8,
};
void print_help_max_pool2d_fwd() void print_help_max_pool2d_fwd()
{ {
std::cout << "arg1: data type (0: fp16; 1: fp32; 5: bf16)\n" std::cout << "arg1: data type (0: fp16; 1: fp32; 2: bf16; 3: int8; 4: fp8)\n"
<< "arg2: verification (0: no; 1: yes)\n" << "arg2: verification (0: no; 1: yes)\n"
<< "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg4: print tensor value (0: no; 1: yes)\n" << "arg4: print tensor value (0: no; 1: yes)\n"
...@@ -70,12 +79,12 @@ void print_help_max_pool2d_fwd() ...@@ -70,12 +79,12 @@ void print_help_max_pool2d_fwd()
int profile_max_pool2d_fwd(int argc, char* argv[]) int profile_max_pool2d_fwd(int argc, char* argv[])
{ {
ck::DataTypeEnum data_type = ck::DataTypeEnum::Half; PoolDataType data_type = PoolDataType::F32;
bool do_verification = true; bool do_verification = true;
int init_method = 0; int init_method = 0;
bool do_log = false; bool do_log = false;
bool time_kernel = true; bool time_kernel = true;
bool return_index = false; bool return_index = false;
std::vector<index_t> in_length = {2, 32, 30, 30}; std::vector<index_t> in_length = {2, 32, 30, 30};
std::vector<index_t> wsize = {2, 2}; std::vector<index_t> wsize = {2, 2};
...@@ -91,7 +100,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[]) ...@@ -91,7 +100,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
} }
else if(argc == 28) else if(argc == 28)
{ {
data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2])); data_type = static_cast<PoolDataType>(std::stoi(argv[2]));
do_verification = std::stoi(argv[3]); do_verification = std::stoi(argv[3]);
init_method = std::stoi(argv[4]); init_method = std::stoi(argv[4]);
do_log = std::stoi(argv[5]); do_log = std::stoi(argv[5]);
...@@ -113,11 +122,13 @@ int profile_max_pool2d_fwd(int argc, char* argv[]) ...@@ -113,11 +122,13 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F32 = float; using F32 = float;
using I32 = int32_t; using I32 = int32_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using NHWC = ck::tensor_layout::convolution::NHWC; using NHWC = ck::tensor_layout::convolution::NHWC;
constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
if(data_type == ck::DataTypeEnum::Half) if(data_type == PoolDataType::F16)
{ {
if(return_index) if(return_index)
{ {
...@@ -150,7 +161,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[]) ...@@ -150,7 +161,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
pad2); pad2);
} }
} }
else if(data_type == ck::DataTypeEnum::BFloat16) else if(data_type == PoolDataType::BF16)
{ {
if(return_index) if(return_index)
{ {
...@@ -189,7 +200,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[]) ...@@ -189,7 +200,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
pad2); pad2);
} }
} }
else if(data_type == ck::DataTypeEnum::Float) else if(data_type == PoolDataType::F32)
{ {
if(return_index) if(return_index)
{ {
...@@ -222,6 +233,72 @@ int profile_max_pool2d_fwd(int argc, char* argv[]) ...@@ -222,6 +233,72 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
pad2); pad2);
} }
} }
else if(data_type == PoolDataType::INT8)
{
if(return_index)
{
ck::profiler::
profile_pool2d_fwd_impl<I8, I8, F32, I32, NHWC, NHWC, ReduceOpId, false, true>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
wdilation,
pad1,
pad2);
}
else
{
ck::profiler::
profile_pool2d_fwd_impl<I8, I8, F32, I32, NHWC, NHWC, ReduceOpId, false, false>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
wdilation,
pad1,
pad2);
}
}
else if(data_type == PoolDataType::F8)
{
if(return_index)
{
ck::profiler::
profile_pool2d_fwd_impl<F8, F8, F32, I32, NHWC, NHWC, ReduceOpId, false, true>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
wdilation,
pad1,
pad2);
}
else
{
ck::profiler::
profile_pool2d_fwd_impl<F8, F8, F32, I32, NHWC, NHWC, ReduceOpId, false, false>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
wdilation,
pad1,
pad2);
}
}
else else
{ {
throw std::runtime_error("not implemented yet"); throw std::runtime_error("not implemented yet");
......
...@@ -14,13 +14,12 @@ class TestAvgPool2dFwd : public ::testing::Test ...@@ -14,13 +14,12 @@ class TestAvgPool2dFwd : public ::testing::Test
using ComputeDataType = std::tuple_element_t<2, Tuple>; using ComputeDataType = std::tuple_element_t<2, Tuple>;
using IndexDataType = std::tuple_element_t<3, Tuple>; using IndexDataType = std::tuple_element_t<3, Tuple>;
std::vector<PoolingParam> params; static std::vector<PoolingParam> params;
void Run() void Run()
{ {
for(auto param : params) for(auto param : params)
{ {
// avg pool
bool success = bool success =
ck::profiler::profile_pool2d_fwd_impl<InDataType, ck::profiler::profile_pool2d_fwd_impl<InDataType,
OutDataType, OutDataType,
...@@ -45,24 +44,102 @@ class TestAvgPool2dFwd : public ::testing::Test ...@@ -45,24 +44,102 @@ class TestAvgPool2dFwd : public ::testing::Test
} }
}; };
using KernelTypes = std::conditional_t< template <typename T>
CK_ENABLE_FP16 && CK_ENABLE_BF16, std::vector<PoolingParam> TestAvgPool2dFwd<T>::params = {
::testing::Types<std::tuple<F16, F16, F32, I32>, {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
std::tuple<F16, F16, F32, I32>, {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
std::tuple<BF16, BF16, F32, I32>, {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
std::tuple<BF16, BF16, F32, I32>, {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}};
std::tuple<F32, F32, F32, I32>,
std::tuple<F32, F32, F32, I32>>,
::testing::Types<std::tuple<F32, F32, F32, I32>, std::tuple<F32, F32, F32, I32>>>;
TYPED_TEST_SUITE(TestAvgPool2dFwd, KernelTypes); using AvgPool2D_F32_Types =
TYPED_TEST(TestAvgPool2dFwd, Test_Pool) ::testing::Types<std::tuple<F32, F32, F32, I32>, std::tuple<F32, F32, F32, I32>>;
using AvgPool2D_F16_Types =
::testing::Types<std::tuple<F16, F16, F32, I32>, std::tuple<F16, F16, F32, I32>>;
using AvgPool2D_BF16_Types =
::testing::Types<std::tuple<I8, I8, F32, I32>, std::tuple<BF16, BF16, F32, I32>>;
using AvgPool2D_I8_Types =
::testing::Types<std::tuple<I8, I8, F32, I32>, std::tuple<I8, I8, F32, I32>>;
using AvgPool2D_F8_Types =
::testing::Types<std::tuple<F8, F8, F32, I32>, std::tuple<F8, F8, F32, I32>>;
template <typename TType>
class AvgPool2D_F32 : public TestAvgPool2dFwd<TType>
{
protected:
void SetUp() override
{
if(!CK_ENABLE_FP32)
{
GTEST_SKIP() << "Skipping AvgPool2D_F32 tests because CK_ENABLE_FP32 is "
"not enabled";
}
}
};
template <typename TType>
class AvgPool2D_F16 : public TestAvgPool2dFwd<TType>
{
protected:
void SetUp() override
{
if(!CK_ENABLE_FP16)
{
GTEST_SKIP() << "Skipping AvgPool2D_F16 tests because CK_ENABLE_FP16 is "
"not enabled";
}
}
};
template <typename TType>
class AvgPool2D_BF16 : public TestAvgPool2dFwd<TType>
{
protected:
void SetUp() override
{
if(!CK_ENABLE_BF16)
{
GTEST_SKIP() << "Skipping AvgPool2D_BF16 tests because CK_ENABLE_BF16 is "
"not enabled";
}
}
};
template <typename TType>
class AvgPool2D_I8 : public TestAvgPool2dFwd<TType>
{
protected:
void SetUp() override
{
if(!CK_ENABLE_INT8)
{
GTEST_SKIP() << "Skipping AvgPool2D_I8 tests because CK_ENABLE_INT8 is "
"not enabled";
}
}
};
template <typename TType>
class AvgPool2D_F8 : public TestAvgPool2dFwd<TType>
{ {
// length, window_length, window_stride, window_dilation, left_pad, right_pad protected:
this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, void SetUp() override
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, {
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}}, if(!CK_ENABLE_FP8)
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}; {
GTEST_SKIP() << "Skipping AvgPool2D_F8 tests because CK_ENABLE_FP8 is "
"not enabled";
}
}
};
TYPED_TEST_SUITE(AvgPool2D_F32, AvgPool2D_F32_Types);
TYPED_TEST_SUITE(AvgPool2D_F16, AvgPool2D_F16_Types);
TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types);
TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types);
this->Run(); TYPED_TEST(AvgPool2D_F32, AvgPool2D_I8_Test) { this->Run(); }
} TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_F8, AvgPool2D_F8_Test) { this->Run(); }
...@@ -15,7 +15,7 @@ class TestMaxPool2dFwd : public ::testing::Test ...@@ -15,7 +15,7 @@ class TestMaxPool2dFwd : public ::testing::Test
using IndexDataType = std::tuple_element_t<3, Tuple>; using IndexDataType = std::tuple_element_t<3, Tuple>;
static constexpr bool ReturnIndex = std::tuple_element_t<4, Tuple>::value; static constexpr bool ReturnIndex = std::tuple_element_t<4, Tuple>::value;
std::vector<PoolingParam> params; static std::vector<PoolingParam> params;
void Run() void Run()
{ {
...@@ -46,27 +46,105 @@ class TestMaxPool2dFwd : public ::testing::Test ...@@ -46,27 +46,105 @@ class TestMaxPool2dFwd : public ::testing::Test
} }
}; };
template <typename T>
std::vector<PoolingParam> TestMaxPool2dFwd<T>::params = {
{{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}};
using true_t = std::integral_constant<bool, true>; using true_t = std::integral_constant<bool, true>;
using false_t = std::integral_constant<bool, false>; using false_t = std::integral_constant<bool, false>;
using KernelTypes = std::conditional_t<CK_ENABLE_FP16 && CK_ENABLE_BF16, using MaxPool2D_F32_Types = ::testing::Types<std::tuple<F32, F32, F32, I32, true_t>,
::testing::Types<std::tuple<F16, F16, F32, I32, true_t>, std::tuple<F32, F32, F32, I32, false_t>>;
std::tuple<F16, F16, F32, I32, false_t>, using MaxPool2D_F16_Types = ::testing::Types<std::tuple<F16, F16, F32, I32, true_t>,
std::tuple<BF16, BF16, F32, I32, true_t>, std::tuple<F16, F16, F32, I32, false_t>>;
std::tuple<BF16, BF16, F32, I32, false_t>, using MaxPool2D_BF16_Types = ::testing::Types<std::tuple<I8, I8, F32, I32, true_t>,
std::tuple<F32, F32, F32, I32, true_t>, std::tuple<BF16, BF16, F32, I32, false_t>>;
std::tuple<F32, F32, F32, I32, false_t>>, using MaxPool2D_I8_Types =
::testing::Types<std::tuple<F32, F32, F32, I32, true_t>, ::testing::Types<std::tuple<I8, I8, F32, I32, true_t>, std::tuple<I8, I8, F32, I32, false_t>>;
std::tuple<F32, F32, F32, I32, false_t>>>; using MaxPool2D_F8_Types =
::testing::Types<std::tuple<F8, F8, F32, I32, true_t>, std::tuple<F8, F8, F32, I32, false_t>>;
template <typename TType>
class MaxPool2D_F32 : public TestMaxPool2dFwd<TType>
{
protected:
void SetUp() override
{
if(!CK_ENABLE_FP32)
{
GTEST_SKIP() << "Skipping MaxPool2D_F32 tests because CK_ENABLE_FP32 is "
"not enabled";
}
}
};
template <typename TType>
class MaxPool2D_F16 : public TestMaxPool2dFwd<TType>
{
protected:
void SetUp() override
{
if(!CK_ENABLE_FP16)
{
GTEST_SKIP() << "Skipping MaxPool2D_F16 tests because CK_ENABLE_FP16 is "
"not enabled";
}
}
};
template <typename TType>
class MaxPool2D_BF16 : public TestMaxPool2dFwd<TType>
{
protected:
void SetUp() override
{
if(!CK_ENABLE_BF16)
{
GTEST_SKIP() << "Skipping MaxPool2D_BF16 tests because CK_ENABLE_BF16 is "
"not enabled";
}
}
};
template <typename TType>
class MaxPool2D_I8 : public TestMaxPool2dFwd<TType>
{
protected:
void SetUp() override
{
if(!CK_ENABLE_INT8)
{
GTEST_SKIP() << "Skipping MaxPool2D_I8 tests because CK_ENABLE_INT8 is "
"not enabled";
}
}
};
TYPED_TEST_SUITE(TestMaxPool2dFwd, KernelTypes); template <typename TType>
TYPED_TEST(TestMaxPool2dFwd, Test_Pool) class MaxPool2D_F8 : public TestMaxPool2dFwd<TType>
{ {
// length, window_length, window_stride, window_dilation, left_pad, right_pad protected:
this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, void SetUp() override
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, {
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}}, if(!CK_ENABLE_FP8)
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}; {
GTEST_SKIP() << "Skipping MaxPool2D_F8 tests because CK_ENABLE_FP8 is "
"not enabled";
}
}
};
TYPED_TEST_SUITE(MaxPool2D_F32, MaxPool2D_F32_Types);
TYPED_TEST_SUITE(MaxPool2D_F16, MaxPool2D_F16_Types);
TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types);
TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types);
this->Run(); TYPED_TEST(MaxPool2D_F32, MaxPool2D_I8_Test) { this->Run(); }
} TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_F8, MaxPool2D_F8_Test) { this->Run(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment