Unverified Commit 8f8a2ce3 authored by jakpiase's avatar jakpiase Committed by GitHub
Browse files

Add pool2d int8 and fp8 instances (#1508)

* add pool2d fp8 and int8

* minor fixes

* add formatting

* add reviewer suggestions

* add reviewer suggestions
parent a4982c3b
......@@ -67,6 +67,36 @@ void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_INT8
// INT8 pooling forward instances (NHWC input, NHWC output, I32 in the index slot).

// Max reduction, no index output.
void add_device_pool2d_fwd_nhwc_i8_instances(
    std::vector<
        std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, false>>>&
        instances);

// Average reduction, no index output (overload of the function above).
void add_device_pool2d_fwd_nhwc_i8_instances(
    std::vector<
        std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, AvgOp, false>>>&
        instances);

// Max reduction that also returns indices (last template argument is true).
void add_device_pool2d_fwd_nhwc_index_i8_instances(
    std::vector<
        std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, true>>>&
        instances);
#endif
#ifdef CK_ENABLE_FP8
// FP8 pooling forward instances (NHWC input, NHWC output, I32 in the index slot).

// Max reduction, no index output.
void add_device_pool2d_fwd_nhwc_f8_instances(
    std::vector<
        std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, false>>>&
        instances);

// Average reduction, no index output (overload of the function above).
void add_device_pool2d_fwd_nhwc_f8_instances(
    std::vector<
        std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, AvgOp, false>>>&
        instances);

// Max reduction that also returns indices (last template argument is true).
void add_device_pool2d_fwd_nhwc_index_f8_instances(
    std::vector<
        std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, true>>>&
        instances);
#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
......@@ -140,6 +170,34 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_INT8
// INT8 arm of the compile-time dispatch: taken when input and output are I8 and the
// index type is I32.
else if constexpr(is_same_v<InDataType, I8> && is_same_v<OutDataType, I8> &&
is_same_v<IndexDataType, I32>)
{
// Index-returning instances are only requested for max pooling with OutputIndex set.
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_i8_instances(op_ptrs);
}
else
{
// Every other combination goes to the non-index registration function, which is
// declared as overloads for MaxOp and AvgOp; presumably the type of op_ptrs picks
// the overload - TODO confirm against the op_ptrs declaration (outside this view).
add_device_pool2d_fwd_nhwc_i8_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP8
// FP8 arm of the compile-time dispatch: taken when input and output are F8 and the
// index type is I32.
else if constexpr(is_same_v<InDataType, F8> && is_same_v<OutDataType, F8> &&
is_same_v<IndexDataType, I32>)
{
// Index-returning instances are only requested for max pooling with OutputIndex set.
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f8_instances(op_ptrs);
}
else
{
// Every other combination goes to the non-index registration function, which is
// declared as overloads for MaxOp and AvgOp; presumably the type of op_ptrs picks
// the overload - TODO confirm against the op_ptrs declaration (outside this view).
add_device_pool2d_fwd_nhwc_f8_instances(op_ptrs);
}
}
#endif
}
......
......@@ -4,5 +4,9 @@ list(APPEND DEVICE_POOL2D_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f16_instance.
device_avg_pool2d_fwd_nhwc_f32_instance.cpp
device_max_pool2d_fwd_nhwc_f32_instance.cpp
device_avg_pool2d_fwd_nhwc_bf16_instance.cpp
device_max_pool2d_fwd_nhwc_bf16_instance.cpp)
device_max_pool2d_fwd_nhwc_bf16_instance.cpp
device_avg_pool2d_fwd_nhwc_i8_instance.cpp
device_max_pool2d_fwd_nhwc_i8_instance.cpp
device_avg_pool2d_fwd_nhwc_f8_instance.cpp
device_max_pool2d_fwd_nhwc_f8_instance.cpp)
add_instance_library(device_pool2d_fwd_instance ${DEVICE_POOL2D_FWD_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
// Registers the FP8 pooling forward instances for this file's ReduceOpId
// (ck::ReduceTensorOp::AVG) with NHWC layouts and no index output.
//
// instances: container handed to add_device_operation_instances together with a
//            default-constructed instance set.
void add_device_pool2d_fwd_nhwc_f8_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, false>>>&
        instances)
{
    // Instance set: F8 in/out, I32 index, F32 as the fourth type argument.
    using InstanceSet = device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, false>;

    add_device_operation_instances(instances, InstanceSet{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
// Registers the INT8 pooling forward instances for this file's ReduceOpId
// (ck::ReduceTensorOp::AVG) with NHWC layouts and no index output.
//
// instances: container handed to add_device_operation_instances together with a
//            default-constructed instance set.
void add_device_pool2d_fwd_nhwc_i8_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, false>>>&
        instances)
{
    // Instance set: I8 in/out, I32 index, F32 as the fourth type argument.
    using InstanceSet = device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, false>;

    add_device_operation_instances(instances, InstanceSet{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
// Registers the FP8 pooling forward instances for this file's ReduceOpId
// (ck::ReduceTensorOp::MAX) with NHWC layouts and no index output.
//
// instances: container handed to add_device_operation_instances together with a
//            default-constructed instance set.
void add_device_pool2d_fwd_nhwc_f8_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, false>>>&
        instances)
{
    // Instance set: F8 in/out, I32 index, F32 as the fourth type argument.
    using InstanceSet = device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, false>;

    add_device_operation_instances(instances, InstanceSet{});
}
// Registers the FP8 pooling forward instances for this file's ReduceOpId
// (ck::ReduceTensorOp::MAX) with NHWC layouts and index output enabled
// (last template argument is true).
//
// instances: container handed to add_device_operation_instances together with a
//            default-constructed instance set.
void add_device_pool2d_fwd_nhwc_index_f8_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, true>>>&
        instances)
{
    // Instance set: F8 in/out, I32 index, F32 as the fourth type argument.
    using IndexedInstanceSet =
        device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, true>;

    add_device_operation_instances(instances, IndexedInstanceSet{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
// Registers the INT8 pooling forward instances for this file's ReduceOpId
// (ck::ReduceTensorOp::MAX) with NHWC layouts and no index output.
//
// instances: container handed to add_device_operation_instances together with a
//            default-constructed instance set.
void add_device_pool2d_fwd_nhwc_i8_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, false>>>&
        instances)
{
    // Instance set: I8 in/out, I32 index, F32 as the fourth type argument.
    using InstanceSet = device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, false>;

    add_device_operation_instances(instances, InstanceSet{});
}
// Registers the INT8 pooling forward instances for this file's ReduceOpId
// (ck::ReduceTensorOp::MAX) with NHWC layouts and index output enabled
// (last template argument is true).
//
// instances: container handed to add_device_operation_instances together with a
//            default-constructed instance set.
void add_device_pool2d_fwd_nhwc_index_i8_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, true>>>&
        instances)
{
    // Instance set: I8 in/out, I32 index, F32 as the fourth type argument.
    using IndexedInstanceSet =
        device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, true>;

    add_device_operation_instances(instances, IndexedInstanceSet{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -15,9 +15,11 @@ namespace device {
namespace instance {
// Short names for the types used when spelling out pooling instances.

// Integer types.
using I32 = int32_t;
using I8  = int8_t;

// Floating-point types.
using F32  = float;
using F16  = ck::half_t;
using BF16 = ck::bhalf_t;
using F8   = ck::f8_t;

// Tensor layout tag.
using NHWC = ck::tensor_layout::convolution::NHWC;
template <typename InDataType,
......
......@@ -49,9 +49,18 @@ struct maxPoolFwdArgParser
}
};
// Element data type selected by the profiler's data-type argument.
//
// The argument is converted with static_cast from std::stoi, so the numeric value of
// each enumerator is part of the command-line contract. The values are written out
// explicitly to match the help text ("0: fp16; 1: fp32; 2: bf16; 3: int8; 4: fp8");
// relying on declaration order previously mapped 0 to F32 and 2 to F16.
enum struct PoolDataType
{
    F16  = 0,
    F32  = 1,
    BF16 = 2,
    INT8 = 3,
    F8   = 4,
};
void print_help_max_pool2d_fwd()
{
std::cout << "arg1: data type (0: fp16; 1: fp32; 5: bf16)\n"
std::cout << "arg1: data type (0: fp16; 1: fp32; 2: bf16; 3: int8; 4: fp8)\n"
<< "arg2: verification (0: no; 1: yes)\n"
<< "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg4: print tensor value (0: no; 1: yes)\n"
......@@ -70,12 +79,12 @@ void print_help_max_pool2d_fwd()
int profile_max_pool2d_fwd(int argc, char* argv[])
{
ck::DataTypeEnum data_type = ck::DataTypeEnum::Half;
bool do_verification = true;
int init_method = 0;
bool do_log = false;
bool time_kernel = true;
bool return_index = false;
PoolDataType data_type = PoolDataType::F32;
bool do_verification = true;
int init_method = 0;
bool do_log = false;
bool time_kernel = true;
bool return_index = false;
std::vector<index_t> in_length = {2, 32, 30, 30};
std::vector<index_t> wsize = {2, 2};
......@@ -91,7 +100,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
}
else if(argc == 28)
{
data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
data_type = static_cast<PoolDataType>(std::stoi(argv[2]));
do_verification = std::stoi(argv[3]);
init_method = std::stoi(argv[4]);
do_log = std::stoi(argv[5]);
......@@ -113,11 +122,13 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
using BF16 = ck::bhalf_t;
using F32 = float;
using I32 = int32_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using NHWC = ck::tensor_layout::convolution::NHWC;
constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
if(data_type == ck::DataTypeEnum::Half)
if(data_type == PoolDataType::F16)
{
if(return_index)
{
......@@ -150,7 +161,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
pad2);
}
}
else if(data_type == ck::DataTypeEnum::BFloat16)
else if(data_type == PoolDataType::BF16)
{
if(return_index)
{
......@@ -189,7 +200,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
pad2);
}
}
else if(data_type == ck::DataTypeEnum::Float)
else if(data_type == PoolDataType::F32)
{
if(return_index)
{
......@@ -222,6 +233,72 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
pad2);
}
}
// INT8 arm: input and output are I8, with F32 and I32 as the third and fourth type
// arguments (presumably compute and index types - TODO confirm against
// profile_pool2d_fwd_impl, which is defined outside this view).
else if(data_type == PoolDataType::INT8)
{
// The two calls differ only in the final template argument, which follows return_index.
if(return_index)
{
ck::profiler::
profile_pool2d_fwd_impl<I8, I8, F32, I32, NHWC, NHWC, ReduceOpId, false, true>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
wdilation,
pad1,
pad2);
}
else
{
ck::profiler::
profile_pool2d_fwd_impl<I8, I8, F32, I32, NHWC, NHWC, ReduceOpId, false, false>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
wdilation,
pad1,
pad2);
}
}
// FP8 arm: input and output are F8, with F32 and I32 as the third and fourth type
// arguments (presumably compute and index types - TODO confirm against
// profile_pool2d_fwd_impl, which is defined outside this view).
else if(data_type == PoolDataType::F8)
{
// The two calls differ only in the final template argument, which follows return_index.
if(return_index)
{
ck::profiler::
profile_pool2d_fwd_impl<F8, F8, F32, I32, NHWC, NHWC, ReduceOpId, false, true>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
wdilation,
pad1,
pad2);
}
else
{
ck::profiler::
profile_pool2d_fwd_impl<F8, F8, F32, I32, NHWC, NHWC, ReduceOpId, false, false>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
wdilation,
pad1,
pad2);
}
}
else
{
throw std::runtime_error("not implemented yet");
......
......@@ -14,13 +14,12 @@ class TestAvgPool2dFwd : public ::testing::Test
using ComputeDataType = std::tuple_element_t<2, Tuple>;
using IndexDataType = std::tuple_element_t<3, Tuple>;
std::vector<PoolingParam> params;
static std::vector<PoolingParam> params;
void Run()
{
for(auto param : params)
{
// avg pool
bool success =
ck::profiler::profile_pool2d_fwd_impl<InDataType,
OutDataType,
......@@ -45,24 +44,102 @@ class TestAvgPool2dFwd : public ::testing::Test
}
};
using KernelTypes = std::conditional_t<
CK_ENABLE_FP16 && CK_ENABLE_BF16,
::testing::Types<std::tuple<F16, F16, F32, I32>,
std::tuple<F16, F16, F32, I32>,
std::tuple<BF16, BF16, F32, I32>,
std::tuple<BF16, BF16, F32, I32>,
std::tuple<F32, F32, F32, I32>,
std::tuple<F32, F32, F32, I32>>,
::testing::Types<std::tuple<F32, F32, F32, I32>, std::tuple<F32, F32, F32, I32>>>;
// Out-of-class definition of the static `params` member that Run() iterates over.
// Each entry appears to list: length, window_length, window_stride, window_dilation,
// left_pad, right_pad - TODO confirm against the PoolingParam definition.
template <typename T>
std::vector<PoolingParam> TestAvgPool2dFwd<T>::params = {
{{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}};
TYPED_TEST_SUITE(TestAvgPool2dFwd, KernelTypes);
TYPED_TEST(TestAvgPool2dFwd, Test_Pool)
// Type lists for the average-pooling typed test suites, one list per element type.
// Tuple layout: <InDataType, OutDataType, ComputeDataType, IndexDataType>.
using AvgPool2D_F32_Types =
    ::testing::Types<std::tuple<F32, F32, F32, I32>, std::tuple<F32, F32, F32, I32>>;
using AvgPool2D_F16_Types =
    ::testing::Types<std::tuple<F16, F16, F32, I32>, std::tuple<F16, F16, F32, I32>>;
// Both entries must be BF16: the first tuple previously named I8, which made the BF16
// suite run an INT8 case.
using AvgPool2D_BF16_Types =
    ::testing::Types<std::tuple<BF16, BF16, F32, I32>, std::tuple<BF16, BF16, F32, I32>>;
using AvgPool2D_I8_Types =
    ::testing::Types<std::tuple<I8, I8, F32, I32>, std::tuple<I8, I8, F32, I32>>;
using AvgPool2D_F8_Types =
    ::testing::Types<std::tuple<F8, F8, F32, I32>, std::tuple<F8, F8, F32, I32>>;
// Typed-test fixture for the F32 average-pooling cases.
// SetUp() marks the test as skipped when CK_ENABLE_FP32 evaluates to false.
template <typename TType>
class AvgPool2D_F32 : public TestAvgPool2dFwd<TType>
{
protected:
void SetUp() override
{
// NOTE(review): CK_ENABLE_FP32 is defined outside this view; the check assumes it
// expands to something usable in a boolean context - confirm.
if(!CK_ENABLE_FP32)
{
GTEST_SKIP() << "Skipping AvgPool2D_F32 tests because CK_ENABLE_FP32 is "
"not enabled";
}
}
};
// Typed-test fixture for the F16 average-pooling cases.
// SetUp() marks the test as skipped when CK_ENABLE_FP16 evaluates to false.
template <typename TType>
class AvgPool2D_F16 : public TestAvgPool2dFwd<TType>
{
protected:
void SetUp() override
{
// NOTE(review): CK_ENABLE_FP16 is defined outside this view; the check assumes it
// expands to something usable in a boolean context - confirm.
if(!CK_ENABLE_FP16)
{
GTEST_SKIP() << "Skipping AvgPool2D_F16 tests because CK_ENABLE_FP16 is "
"not enabled";
}
}
};
// Typed-test fixture for the BF16 average-pooling cases.
// SetUp() marks the test as skipped when CK_ENABLE_BF16 evaluates to false.
template <typename TType>
class AvgPool2D_BF16 : public TestAvgPool2dFwd<TType>
{
protected:
void SetUp() override
{
// NOTE(review): CK_ENABLE_BF16 is defined outside this view; the check assumes it
// expands to something usable in a boolean context - confirm.
if(!CK_ENABLE_BF16)
{
GTEST_SKIP() << "Skipping AvgPool2D_BF16 tests because CK_ENABLE_BF16 is "
"not enabled";
}
}
};
// Typed-test fixture for the INT8 average-pooling cases.
// SetUp() marks the test as skipped when CK_ENABLE_INT8 evaluates to false.
template <typename TType>
class AvgPool2D_I8 : public TestAvgPool2dFwd<TType>
{
protected:
void SetUp() override
{
// NOTE(review): CK_ENABLE_INT8 is defined outside this view; the check assumes it
// expands to something usable in a boolean context - confirm.
if(!CK_ENABLE_INT8)
{
GTEST_SKIP() << "Skipping AvgPool2D_I8 tests because CK_ENABLE_INT8 is "
"not enabled";
}
}
};
template <typename TType>
class AvgPool2D_F8 : public TestAvgPool2dFwd<TType>
{
// length, window_length, window_stride, window_dilation, left_pad, right_pad
this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}};
protected:
void SetUp() override
{
if(!CK_ENABLE_FP8)
{
GTEST_SKIP() << "Skipping AvgPool2D_F8 tests because CK_ENABLE_FP8 is "
"not enabled";
}
}
};
// Bind each average-pooling fixture to its list of type tuples.
TYPED_TEST_SUITE(AvgPool2D_F32, AvgPool2D_F32_Types);
TYPED_TEST_SUITE(AvgPool2D_F16, AvgPool2D_F16_Types);
TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types);
TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types);
this->Run();
}
// One typed test per element-type suite; each simply runs the fixture's Run().
// The test name carries the same type tag as its suite (the F32 suite's test was
// previously mislabelled AvgPool2D_I8_Test).
TYPED_TEST(AvgPool2D_F32, AvgPool2D_F32_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_F8, AvgPool2D_F8_Test) { this->Run(); }
......@@ -15,7 +15,7 @@ class TestMaxPool2dFwd : public ::testing::Test
using IndexDataType = std::tuple_element_t<3, Tuple>;
static constexpr bool ReturnIndex = std::tuple_element_t<4, Tuple>::value;
std::vector<PoolingParam> params;
static std::vector<PoolingParam> params;
void Run()
{
......@@ -46,27 +46,105 @@ class TestMaxPool2dFwd : public ::testing::Test
}
};
// Out-of-class definition of the static `params` member of TestMaxPool2dFwd<T>.
// Each entry appears to list: length, window_length, window_stride, window_dilation,
// left_pad, right_pad - TODO confirm against the PoolingParam definition.
template <typename T>
std::vector<PoolingParam> TestMaxPool2dFwd<T>::params = {
{{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}};
// Compile-time boolean tags used as the fifth tuple element (ReturnIndex) of the
// max-pooling type lists. std::true_type / std::false_type are the standard aliases
// for std::integral_constant<bool, true/false>.
using true_t  = std::true_type;
using false_t = std::false_type;
using KernelTypes = std::conditional_t<CK_ENABLE_FP16 && CK_ENABLE_BF16,
::testing::Types<std::tuple<F16, F16, F32, I32, true_t>,
std::tuple<F16, F16, F32, I32, false_t>,
std::tuple<BF16, BF16, F32, I32, true_t>,
std::tuple<BF16, BF16, F32, I32, false_t>,
std::tuple<F32, F32, F32, I32, true_t>,
std::tuple<F32, F32, F32, I32, false_t>>,
::testing::Types<std::tuple<F32, F32, F32, I32, true_t>,
std::tuple<F32, F32, F32, I32, false_t>>>;
// Type lists for the max-pooling typed test suites, one list per element type.
// Tuple layout: <InDataType, OutDataType, ComputeDataType, IndexDataType, ReturnIndex>;
// each list covers ReturnIndex = true and ReturnIndex = false.
using MaxPool2D_F32_Types = ::testing::Types<std::tuple<F32, F32, F32, I32, true_t>,
                                             std::tuple<F32, F32, F32, I32, false_t>>;
using MaxPool2D_F16_Types = ::testing::Types<std::tuple<F16, F16, F32, I32, true_t>,
                                             std::tuple<F16, F16, F32, I32, false_t>>;
// Both entries must be BF16: the index-returning tuple previously named I8, so the
// BF16 suite never exercised BF16 with ReturnIndex = true.
using MaxPool2D_BF16_Types = ::testing::Types<std::tuple<BF16, BF16, F32, I32, true_t>,
                                              std::tuple<BF16, BF16, F32, I32, false_t>>;
using MaxPool2D_I8_Types =
    ::testing::Types<std::tuple<I8, I8, F32, I32, true_t>, std::tuple<I8, I8, F32, I32, false_t>>;
using MaxPool2D_F8_Types =
    ::testing::Types<std::tuple<F8, F8, F32, I32, true_t>, std::tuple<F8, F8, F32, I32, false_t>>;
// Typed-test fixture for the F32 max-pooling cases.
// SetUp() marks the test as skipped when CK_ENABLE_FP32 evaluates to false.
template <typename TType>
class MaxPool2D_F32 : public TestMaxPool2dFwd<TType>
{
protected:
void SetUp() override
{
// NOTE(review): CK_ENABLE_FP32 is defined outside this view; the check assumes it
// expands to something usable in a boolean context - confirm.
if(!CK_ENABLE_FP32)
{
GTEST_SKIP() << "Skipping MaxPool2D_F32 tests because CK_ENABLE_FP32 is "
"not enabled";
}
}
};
// Typed-test fixture for the F16 max-pooling cases.
// SetUp() marks the test as skipped when CK_ENABLE_FP16 evaluates to false.
template <typename TType>
class MaxPool2D_F16 : public TestMaxPool2dFwd<TType>
{
protected:
void SetUp() override
{
// NOTE(review): CK_ENABLE_FP16 is defined outside this view; the check assumes it
// expands to something usable in a boolean context - confirm.
if(!CK_ENABLE_FP16)
{
GTEST_SKIP() << "Skipping MaxPool2D_F16 tests because CK_ENABLE_FP16 is "
"not enabled";
}
}
};
// Typed-test fixture for the BF16 max-pooling cases.
// SetUp() marks the test as skipped when CK_ENABLE_BF16 evaluates to false.
template <typename TType>
class MaxPool2D_BF16 : public TestMaxPool2dFwd<TType>
{
protected:
void SetUp() override
{
// NOTE(review): CK_ENABLE_BF16 is defined outside this view; the check assumes it
// expands to something usable in a boolean context - confirm.
if(!CK_ENABLE_BF16)
{
GTEST_SKIP() << "Skipping MaxPool2D_BF16 tests because CK_ENABLE_BF16 is "
"not enabled";
}
}
};
// Typed-test fixture for the INT8 max-pooling cases.
// SetUp() marks the test as skipped when CK_ENABLE_INT8 evaluates to false.
template <typename TType>
class MaxPool2D_I8 : public TestMaxPool2dFwd<TType>
{
protected:
void SetUp() override
{
// NOTE(review): CK_ENABLE_INT8 is defined outside this view; the check assumes it
// expands to something usable in a boolean context - confirm.
if(!CK_ENABLE_INT8)
{
GTEST_SKIP() << "Skipping MaxPool2D_I8 tests because CK_ENABLE_INT8 is "
"not enabled";
}
}
};
TYPED_TEST_SUITE(TestMaxPool2dFwd, KernelTypes);
TYPED_TEST(TestMaxPool2dFwd, Test_Pool)
template <typename TType>
class MaxPool2D_F8 : public TestMaxPool2dFwd<TType>
{
// length, window_length, window_stride, window_dilation, left_pad, right_pad
this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}};
protected:
void SetUp() override
{
if(!CK_ENABLE_FP8)
{
GTEST_SKIP() << "Skipping MaxPool2D_F8 tests because CK_ENABLE_FP8 is "
"not enabled";
}
}
};
// Bind each max-pooling fixture to its list of type tuples.
TYPED_TEST_SUITE(MaxPool2D_F32, MaxPool2D_F32_Types);
TYPED_TEST_SUITE(MaxPool2D_F16, MaxPool2D_F16_Types);
TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types);
TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types);
this->Run();
}
// One typed test per element-type suite; each simply runs the fixture's Run().
// The test name carries the same type tag as its suite (the F32 suite's test was
// previously mislabelled MaxPool2D_I8_Test).
TYPED_TEST(MaxPool2D_F32, MaxPool2D_F32_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_F8, MaxPool2D_F8_Test) { this->Run(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment