"vscode:/vscode.git/clone" did not exist on "3dbba11945463a3ca28e5610d23b38225f8357d4"
Commit 6975cb8f authored by carlushuang

Merge remote-tracking branch 'origin/develop' into ck_tile/fav3_fwd_sept

parents 33aff2ef 6834e5ee
@@ -2,5 +2,6 @@ set(DEVICE_MAXPOOL_BWD_INSTANCES)
 list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp
                                          device_max_pool_bwd_bf16_instance.cpp
                                          device_max_pool_bwd_f32_instance.cpp
+                                         device_max_pool_bwd_f8_instance.cpp
                                          device_max_pool_bwd_int8_instance.cpp)
 add_instance_library(device_max_pool_bwd_instance ${DEVICE_MAXPOOL_BWD_INSTANCES})
new file: device_max_pool_bwd_f8_instance.cpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "max_pool_bwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_maxpool_bwd_f8_instances(
std::vector<std::unique_ptr<DeviceMaxPoolBwd<F8, I32, F8>>>& instances)
{
add_device_operation_instances(instances, device_maxpool_bwd_instances<F8, I32, F8>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
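These instance files follow CK's factory pattern: the library appends one entry per tuned kernel configuration to a caller-supplied vector of base-class pointers, and a profiler later times each candidate and keeps the fastest one that supports the problem shape. A minimal consumer sketch, assuming the relevant CK instance header is available (the include path below is a guess) and that instances expose GetTypeString() via CK's BaseOperator:

// Sketch only: the include path and GetTypeString() are assumptions,
// not verified against this commit.
#include <iostream>
#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp" // hypothetical path

int main()
{
    using namespace ck::tensor_operation::device;
    using F8  = ck::f8_t;
    using I32 = int32_t;

    // The factory fills the vector with one object per registered kernel.
    std::vector<std::unique_ptr<DeviceMaxPoolBwd<F8, I32, F8>>> ops;
    instance::add_device_maxpool_bwd_f8_instances(ops);

    std::cout << "f8 maxpool-bwd instances: " << ops.size() << '\n';
    for(const auto& op : ops)
        std::cout << "  " << op->GetTypeString() << '\n'; // assumed BaseOperator API
}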
 set(FMHA_CPP_FOLDER ${CMAKE_CURRENT_BINARY_DIR})
 set(FMHA_SRC_FOLDER ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/)
 set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/)
+# python stuff
+# Usage: for a customized Python location, pass cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8"
+# CK codegen requires dataclasses, which were added in Python 3.7
+# Python 3.8 is recommended as good general practice, since it is the default on Ubuntu 20.04
 if(NOT CK_USE_ALTERNATIVE_PYTHON)
     find_package(PythonInterp 3 REQUIRED)
 else()
     message("Using alternative python version")
     set(EXTRA_PYTHON_PATH)
+    # this is overly restrictive; we may need to be more flexible on the following
     string(REPLACE "/bin/python3.8" "" EXTRA_PYTHON_PATH "${CK_USE_ALTERNATIVE_PYTHON}")
     message("alternative python path is: ${EXTRA_PYTHON_PATH}")
     find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
     add_definitions(-DPython3_EXECUTABLE="${CK_USE_ALTERNATIVE_PYTHON}")
     set(Python3_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
     set(PYTHON_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
     set(ENV{LD_LIBRARY_PATH} "${EXTRA_PYTHON_PATH}/lib:$ENV{LD_LIBRARY_PATH}")
 endif()
 rocm_install(DIRECTORY ${CK_TILE_SRC_FOLDER} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile)
@@ -23,14 +27,24 @@ rocm_install(FILES ${MHA_HEADERS} DESTINATION include/ck_tile/ops)
 # headers for building lib
 file(COPY ${MHA_HEADERS} DESTINATION ${FMHA_CPP_FOLDER})
+# Delete the blob list if it exists, to avoid appending to stale content.
+if(EXISTS ${FMHA_CPP_FOLDER}/blob_list.txt)
+    file(REMOVE ${FMHA_CPP_FOLDER}/blob_list.txt)
+endif()
 # generate a list of kernels, but do not actually emit files at config stage
 execute_process(
     COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
     --list_blobs ${FMHA_CPP_FOLDER}/blob_list.txt
+    RESULT_VARIABLE ret
 )
-file(STRINGS ${FMHA_CPP_FOLDER}/blob_list.txt FMHA_FWD_GEN_BLOBS)
+if(ret AND NOT ret EQUAL 0)
+    message(FATAL_ERROR "CK Tile MHA FAILED to generate a list of kernels via Python.")
+else()
+    file(STRINGS ${FMHA_CPP_FOLDER}/blob_list.txt FMHA_FWD_GEN_BLOBS)
+endif()
-# actually generate the cpp files
+# actually generate the kernel content now
 add_custom_command(
     OUTPUT ${FMHA_FWD_GEN_BLOBS}
     COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
@@ -52,8 +66,6 @@ add_custom_target(generate_cpp_files DEPENDS ${FMHA_FWD_GEN_BLOBS})
 add_instance_library(device_mha_instance ${device_files})
 if (TARGET device_mha_instance)
     add_dependencies(device_mha_instance generate_cpp_files)
 endif()
...
@@ -4,5 +4,9 @@ list(APPEND DEVICE_POOL2D_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f16_instance.
     device_avg_pool2d_fwd_nhwc_f32_instance.cpp
     device_max_pool2d_fwd_nhwc_f32_instance.cpp
     device_avg_pool2d_fwd_nhwc_bf16_instance.cpp
-    device_max_pool2d_fwd_nhwc_bf16_instance.cpp)
+    device_max_pool2d_fwd_nhwc_bf16_instance.cpp
+    device_avg_pool2d_fwd_nhwc_i8_instance.cpp
+    device_max_pool2d_fwd_nhwc_i8_instance.cpp
+    device_avg_pool2d_fwd_nhwc_f8_instance.cpp
+    device_max_pool2d_fwd_nhwc_f8_instance.cpp)
 add_instance_library(device_pool2d_fwd_instance ${DEVICE_POOL2D_FWD_INSTANCES})
new file: device_avg_pool2d_fwd_nhwc_f8_instance.cpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pool2d_fwd_nhwc_f8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
new file: device_avg_pool2d_fwd_nhwc_i8_instance.cpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pool2d_fwd_nhwc_i8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, false>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
new file: device_max_pool2d_fwd_nhwc_f8_instance.cpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pool2d_fwd_nhwc_f8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
}
void add_device_pool2d_fwd_nhwc_index_f8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, true>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, true>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
new file: device_max_pool2d_fwd_nhwc_i8_instance.cpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pool2d_fwd_nhwc_i8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, false>{});
}
void add_device_pool2d_fwd_nhwc_index_i8_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, true>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, true>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
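Note the pattern shared by the four new files: the trailing boolean template argument of DevicePoolFwd selects whether the kernel also emits the argmax index tensor (only meaningful for MAX pooling, hence the extra add_device_pool2d_fwd_nhwc_index_*_instances entry points there), while F32 serves as the compute type even for the 8-bit variants so accumulation happens at higher precision. A hedged sketch of choosing between the two f8 MAX-pool entry points, assuming the aliases and declarations from the instance files above are in scope:

// Sketch only: assumes the CK declarations used above (DevicePoolFwd,
// the F8/I32/NHWC aliases, and the two factory functions) are visible.
#include <memory>
#include <vector>

using ck::tensor_operation::device::DevicePoolFwd;
using namespace ck::tensor_operation::device::instance;

constexpr auto MaxOp = ck::ReduceTensorOp::MAX;

// OutputIndex = true instances also write argmax indices; false ones do not.
// The same function name is overloaded across AVG and MAX files; the
// ReduceOpId in the vector's element type selects the right overload.
void collect_f8_max_pool_instances(bool want_indices)
{
    if(want_indices)
    {
        std::vector<std::unique_ptr<
            DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, MaxOp, true>>> ops;
        add_device_pool2d_fwd_nhwc_index_f8_instances(ops);
    }
    else
    {
        std::vector<std::unique_ptr<
            DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, MaxOp, false>>> ops;
        add_device_pool2d_fwd_nhwc_f8_instances(ops);
    }
}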
@@ -15,9 +15,11 @@ namespace device {
 namespace instance {
 using I32 = int32_t;
+using F32 = float;
 using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
-using F32 = float;
+using I8 = int8_t;
+using F8 = ck::f8_t;
 using NHWC = ck::tensor_layout::convolution::NHWC;
 template <typename InDataType,
...
@@ -49,9 +49,18 @@ struct maxPoolFwdArgParser
     }
 };
+enum struct PoolDataType
+{
+    F32 = 0,
+    BF16,
+    F16,
+    INT8,
+    F8,
+};
 void print_help_max_pool2d_fwd()
 {
-    std::cout << "arg1: data type (0: fp16; 1: fp32; 5: bf16)\n"
+    std::cout << "arg1: data type (0: fp32; 1: bf16; 2: fp16; 3: int8; 4: fp8)\n"
               << "arg2: verification (0: no; 1: yes)\n"
               << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
               << "arg4: print tensor value (0: no; 1: yes)\n"
@@ -70,12 +79,12 @@ void print_help_max_pool2d_fwd()
 int profile_max_pool2d_fwd(int argc, char* argv[])
 {
-    ck::DataTypeEnum data_type = ck::DataTypeEnum::Half;
+    PoolDataType data_type = PoolDataType::F32;
     bool do_verification = true;
     int init_method      = 0;
     bool do_log          = false;
     bool time_kernel     = true;
     bool return_index    = false;
     std::vector<index_t> in_length = {2, 32, 30, 30};
     std::vector<index_t> wsize     = {2, 2};
@@ -91,7 +100,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
     }
     else if(argc == 28)
     {
-        data_type       = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
+        data_type       = static_cast<PoolDataType>(std::stoi(argv[2]));
         do_verification = std::stoi(argv[3]);
         init_method     = std::stoi(argv[4]);
         do_log          = std::stoi(argv[5]);
@@ -113,11 +122,13 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
     using BF16 = ck::bhalf_t;
     using F32  = float;
     using I32  = int32_t;
+    using F8   = ck::f8_t;
+    using I8   = int8_t;
     using NHWC = ck::tensor_layout::convolution::NHWC;
     constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
-    if(data_type == ck::DataTypeEnum::Half)
+    if(data_type == PoolDataType::F16)
     {
         if(return_index)
         {
@@ -150,7 +161,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
                 pad2);
         }
     }
-    else if(data_type == ck::DataTypeEnum::BFloat16)
+    else if(data_type == PoolDataType::BF16)
     {
         if(return_index)
         {
@@ -189,7 +200,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
                 pad2);
         }
     }
-    else if(data_type == ck::DataTypeEnum::Float)
+    else if(data_type == PoolDataType::F32)
     {
         if(return_index)
         {
@@ -222,6 +233,72 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
                 pad2);
         }
     }
+    else if(data_type == PoolDataType::INT8)
+    {
+        if(return_index)
+        {
+            ck::profiler::
+                profile_pool2d_fwd_impl<I8, I8, F32, I32, NHWC, NHWC, ReduceOpId, false, true>(
+                    do_verification,
+                    init_method,
+                    do_log,
+                    time_kernel,
+                    in_length,
+                    wsize,
+                    wstride,
+                    wdilation,
+                    pad1,
+                    pad2);
+        }
+        else
+        {
+            ck::profiler::
+                profile_pool2d_fwd_impl<I8, I8, F32, I32, NHWC, NHWC, ReduceOpId, false, false>(
+                    do_verification,
+                    init_method,
+                    do_log,
+                    time_kernel,
+                    in_length,
+                    wsize,
+                    wstride,
+                    wdilation,
+                    pad1,
+                    pad2);
+        }
+    }
+    else if(data_type == PoolDataType::F8)
+    {
+        if(return_index)
+        {
+            ck::profiler::
+                profile_pool2d_fwd_impl<F8, F8, F32, I32, NHWC, NHWC, ReduceOpId, false, true>(
+                    do_verification,
+                    init_method,
+                    do_log,
+                    time_kernel,
+                    in_length,
+                    wsize,
+                    wstride,
+                    wdilation,
+                    pad1,
+                    pad2);
+        }
+        else
+        {
+            ck::profiler::
+                profile_pool2d_fwd_impl<F8, F8, F32, I32, NHWC, NHWC, ReduceOpId, false, false>(
+                    do_verification,
+                    init_method,
+                    do_log,
+                    time_kernel,
+                    in_length,
+                    wsize,
+                    wstride,
+                    wdilation,
+                    pad1,
+                    pad2);
+        }
+    }
     else
     {
         throw std::runtime_error("not implemented yet");
...
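One caveat with the enum-driven dispatch above: static_cast<PoolDataType>(std::stoi(argv[2])) happily produces an out-of-range enumerator for bad input and then falls through to the final throw with an unhelpful message. A hedged sketch of a range-checked parse (hypothetical helper, not part of this commit):

// Hypothetical helper, not in the CK profiler: validates the CLI integer
// before casting it to PoolDataType, so bad input fails fast and clearly.
#include <stdexcept>
#include <string>

enum struct PoolDataType { F32 = 0, BF16, F16, INT8, F8 };

PoolDataType parse_pool_data_type(const std::string& arg)
{
    const int v = std::stoi(arg);
    if(v < 0 || v > static_cast<int>(PoolDataType::F8))
        throw std::invalid_argument("data type must be 0..4 (fp32/bf16/fp16/int8/fp8)");
    return static_cast<PoolDataType>(v);
}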
@@ -14,13 +14,12 @@ class TestAvgPool2dFwd : public ::testing::Test
     using ComputeDataType = std::tuple_element_t<2, Tuple>;
     using IndexDataType   = std::tuple_element_t<3, Tuple>;
-    std::vector<PoolingParam> params;
+    static std::vector<PoolingParam> params;
     void Run()
     {
         for(auto param : params)
         {
-            // avg pool
             bool success =
                 ck::profiler::profile_pool2d_fwd_impl<InDataType,
                                                       OutDataType,
@@ -45,24 +44,102 @@ class TestAvgPool2dFwd : public ::testing::Test
     }
 };
-using KernelTypes = std::conditional_t<
-    CK_ENABLE_FP16 && CK_ENABLE_BF16,
-    ::testing::Types<std::tuple<F16, F16, F32, I32>,
-                     std::tuple<F16, F16, F32, I32>,
-                     std::tuple<BF16, BF16, F32, I32>,
-                     std::tuple<BF16, BF16, F32, I32>,
-                     std::tuple<F32, F32, F32, I32>,
-                     std::tuple<F32, F32, F32, I32>>,
-    ::testing::Types<std::tuple<F32, F32, F32, I32>, std::tuple<F32, F32, F32, I32>>>;
-TYPED_TEST_SUITE(TestAvgPool2dFwd, KernelTypes);
-TYPED_TEST(TestAvgPool2dFwd, Test_Pool)
-{
-    // length, window_length, window_stride, window_dilation, left_pad, right_pad
-    this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
-                    {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
-                    {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
-                    {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}};
-    this->Run();
-}
+// length, window_length, window_stride, window_dilation, left_pad, right_pad
+template <typename T>
+std::vector<PoolingParam> TestAvgPool2dFwd<T>::params = {
+    {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
+     {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
+     {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
+     {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}};
+
+using AvgPool2D_F32_Types =
+    ::testing::Types<std::tuple<F32, F32, F32, I32>, std::tuple<F32, F32, F32, I32>>;
+using AvgPool2D_F16_Types =
+    ::testing::Types<std::tuple<F16, F16, F32, I32>, std::tuple<F16, F16, F32, I32>>;
+using AvgPool2D_BF16_Types =
+    ::testing::Types<std::tuple<BF16, BF16, F32, I32>, std::tuple<BF16, BF16, F32, I32>>;
+using AvgPool2D_I8_Types =
+    ::testing::Types<std::tuple<I8, I8, F32, I32>, std::tuple<I8, I8, F32, I32>>;
+using AvgPool2D_F8_Types =
+    ::testing::Types<std::tuple<F8, F8, F32, I32>, std::tuple<F8, F8, F32, I32>>;
+
+template <typename TType>
+class AvgPool2D_F32 : public TestAvgPool2dFwd<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_FP32)
+        {
+            GTEST_SKIP() << "Skipping AvgPool2D_F32 tests because CK_ENABLE_FP32 is "
+                            "not enabled";
+        }
+    }
+};
+
+template <typename TType>
+class AvgPool2D_F16 : public TestAvgPool2dFwd<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_FP16)
+        {
+            GTEST_SKIP() << "Skipping AvgPool2D_F16 tests because CK_ENABLE_FP16 is "
+                            "not enabled";
+        }
+    }
+};
+
+template <typename TType>
+class AvgPool2D_BF16 : public TestAvgPool2dFwd<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_BF16)
+        {
+            GTEST_SKIP() << "Skipping AvgPool2D_BF16 tests because CK_ENABLE_BF16 is "
+                            "not enabled";
+        }
+    }
+};
+
+template <typename TType>
+class AvgPool2D_I8 : public TestAvgPool2dFwd<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_INT8)
+        {
+            GTEST_SKIP() << "Skipping AvgPool2D_I8 tests because CK_ENABLE_INT8 is "
+                            "not enabled";
+        }
+    }
+};
+
+template <typename TType>
+class AvgPool2D_F8 : public TestAvgPool2dFwd<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_FP8)
+        {
+            GTEST_SKIP() << "Skipping AvgPool2D_F8 tests because CK_ENABLE_FP8 is "
+                            "not enabled";
+        }
+    }
+};
+
+TYPED_TEST_SUITE(AvgPool2D_F32, AvgPool2D_F32_Types);
+TYPED_TEST_SUITE(AvgPool2D_F16, AvgPool2D_F16_Types);
+TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
+TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types);
+TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types);
+
+TYPED_TEST(AvgPool2D_F32, AvgPool2D_F32_Test) { this->Run(); }
+TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); }
+TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); }
+TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); }
+TYPED_TEST(AvgPool2D_F8, AvgPool2D_F8_Test) { this->Run(); }
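The switch from a member `params` assigned inside the test body to a `static` member with a single out-of-class definition is what lets every typed suite above share one parameter list. A self-contained sketch of the pattern (names here are illustrative, not from the CK tree):

// Minimal standalone sketch: a typed-test fixture with a static data
// member defined once for all type instantiations.
#include <gtest/gtest.h>
#include <vector>

template <typename T>
class StaticParamFixture : public ::testing::Test
{
protected:
    static std::vector<int> params; // shared per-type, defined below
};

// One out-of-class definition covers every instantiation of the template.
template <typename T>
std::vector<int> StaticParamFixture<T>::params = {1, 2, 3};

using MyTypes = ::testing::Types<int, float>;
TYPED_TEST_SUITE(StaticParamFixture, MyTypes);

TYPED_TEST(StaticParamFixture, SeesSharedParams)
{
    // 'this->' is required because the fixture is a dependent base class.
    EXPECT_EQ(this->params.size(), 3u);
}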
@@ -55,6 +55,7 @@ using Max_Pool_2D_f32_types = ::testing::Types<std::tuple<F32, F32, I32>>;
 using Max_Pool_2D_int8_types = ::testing::Types<std::tuple<I8, I8, I32>>;
 using Max_Pool_2D_f16_types  = ::testing::Types<std::tuple<F16, F16, I32>>;
 using Max_Pool_2D_bf16_types = ::testing::Types<std::tuple<BF16, BF16, I32>>;
+using Max_Pool_2D_f8_types   = ::testing::Types<std::tuple<F8, F8, I32>>;
 template <typename TType>
 class MaxPool2D_f32 : public MaxPool2dBWDTest<TType>
@@ -108,10 +109,24 @@ class MaxPool2D_bf16 : public MaxPool2dBWDTest<TType>
     }
 };
+template <typename TType>
+class MaxPool2D_f8 : public MaxPool2dBWDTest<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_FP8)
+        {
+            GTEST_SKIP() << "Skipping MaxPool2D_f8 tests because CK_ENABLE_FP8 is not enabled";
+        }
+    }
+};
+
 TYPED_TEST_SUITE(MaxPool2D_f32, Max_Pool_2D_f32_types);
 TYPED_TEST_SUITE(MaxPool2D_int8, Max_Pool_2D_int8_types);
 TYPED_TEST_SUITE(MaxPool2D_f16, Max_Pool_2D_f16_types);
 TYPED_TEST_SUITE(MaxPool2D_bf16, Max_Pool_2D_bf16_types);
+TYPED_TEST_SUITE(MaxPool2D_f8, Max_Pool_2D_f8_types);
 TYPED_TEST(MaxPool2D_f32, MaxPool2DTest_f32) { this->Run(); }
@@ -120,3 +135,5 @@ TYPED_TEST(MaxPool2D_int8, MaxPool2DTest_int8) { this->Run(); }
 TYPED_TEST(MaxPool2D_f16, MaxPool2DTest_f16) { this->Run(); }
 TYPED_TEST(MaxPool2D_bf16, MaxPool2DTest_bf16) { this->Run(); }
+TYPED_TEST(MaxPool2D_f8, MaxPool2DTest_f8) { this->Run(); }
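Compared with the compile-time std::conditional_t type lists (the same trick the forward-pool tests below delete), the fixtures above always compile and register every suite, and GTEST_SKIP() in SetUp() reports unsupported data types as skipped at runtime instead of silently omitting them. A minimal sketch of the guard, with a stand-in flag for things like CK_ENABLE_FP8:

// Minimal sketch of the runtime-skip guard used above.
// MY_FEATURE_ENABLED is a stand-in for build flags like CK_ENABLE_FP8.
#include <gtest/gtest.h>

#ifndef MY_FEATURE_ENABLED
#define MY_FEATURE_ENABLED 0
#endif

class GuardedTest : public ::testing::Test
{
protected:
    void SetUp() override
    {
        if(!MY_FEATURE_ENABLED)
        {
            // Marks the test SKIPPED at runtime; TestBody() never runs.
            GTEST_SKIP() << "feature not enabled in this build";
        }
    }
};

TEST_F(GuardedTest, RunsOnlyWhenEnabled) { SUCCEED(); }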
@@ -15,7 +15,7 @@ class TestMaxPool2dFwd : public ::testing::Test
     using IndexDataType = std::tuple_element_t<3, Tuple>;
     static constexpr bool ReturnIndex = std::tuple_element_t<4, Tuple>::value;
-    std::vector<PoolingParam> params;
+    static std::vector<PoolingParam> params;
     void Run()
     {
@@ -46,27 +46,105 @@ class TestMaxPool2dFwd : public ::testing::Test
     }
 };
+// length, window_length, window_stride, window_dilation, left_pad, right_pad
+template <typename T>
+std::vector<PoolingParam> TestMaxPool2dFwd<T>::params = {
+    {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
+     {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
+     {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
+     {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}};
 using true_t  = std::integral_constant<bool, true>;
 using false_t = std::integral_constant<bool, false>;
-using KernelTypes = std::conditional_t<CK_ENABLE_FP16 && CK_ENABLE_BF16,
-                                       ::testing::Types<std::tuple<F16, F16, F32, I32, true_t>,
-                                                        std::tuple<F16, F16, F32, I32, false_t>,
-                                                        std::tuple<BF16, BF16, F32, I32, true_t>,
-                                                        std::tuple<BF16, BF16, F32, I32, false_t>,
-                                                        std::tuple<F32, F32, F32, I32, true_t>,
-                                                        std::tuple<F32, F32, F32, I32, false_t>>,
-                                       ::testing::Types<std::tuple<F32, F32, F32, I32, true_t>,
-                                                        std::tuple<F32, F32, F32, I32, false_t>>>;
-TYPED_TEST_SUITE(TestMaxPool2dFwd, KernelTypes);
-TYPED_TEST(TestMaxPool2dFwd, Test_Pool)
-{
-    // length, window_length, window_stride, window_dilation, left_pad, right_pad
-    this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
-                    {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
-                    {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
-                    {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}};
-    this->Run();
-}
+using MaxPool2D_F32_Types  = ::testing::Types<std::tuple<F32, F32, F32, I32, true_t>,
+                                              std::tuple<F32, F32, F32, I32, false_t>>;
+using MaxPool2D_F16_Types  = ::testing::Types<std::tuple<F16, F16, F32, I32, true_t>,
+                                              std::tuple<F16, F16, F32, I32, false_t>>;
+using MaxPool2D_BF16_Types = ::testing::Types<std::tuple<BF16, BF16, F32, I32, true_t>,
+                                              std::tuple<BF16, BF16, F32, I32, false_t>>;
+using MaxPool2D_I8_Types =
+    ::testing::Types<std::tuple<I8, I8, F32, I32, true_t>, std::tuple<I8, I8, F32, I32, false_t>>;
+using MaxPool2D_F8_Types =
+    ::testing::Types<std::tuple<F8, F8, F32, I32, true_t>, std::tuple<F8, F8, F32, I32, false_t>>;
+
+template <typename TType>
+class MaxPool2D_F32 : public TestMaxPool2dFwd<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_FP32)
+        {
+            GTEST_SKIP() << "Skipping MaxPool2D_F32 tests because CK_ENABLE_FP32 is "
+                            "not enabled";
+        }
+    }
+};
+
+template <typename TType>
+class MaxPool2D_F16 : public TestMaxPool2dFwd<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_FP16)
+        {
+            GTEST_SKIP() << "Skipping MaxPool2D_F16 tests because CK_ENABLE_FP16 is "
+                            "not enabled";
+        }
+    }
+};
+
+template <typename TType>
+class MaxPool2D_BF16 : public TestMaxPool2dFwd<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_BF16)
+        {
+            GTEST_SKIP() << "Skipping MaxPool2D_BF16 tests because CK_ENABLE_BF16 is "
+                            "not enabled";
+        }
+    }
+};
+
+template <typename TType>
+class MaxPool2D_I8 : public TestMaxPool2dFwd<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_INT8)
+        {
+            GTEST_SKIP() << "Skipping MaxPool2D_I8 tests because CK_ENABLE_INT8 is "
+                            "not enabled";
+        }
+    }
+};
+
+template <typename TType>
+class MaxPool2D_F8 : public TestMaxPool2dFwd<TType>
+{
+protected:
+    void SetUp() override
+    {
+        if(!CK_ENABLE_FP8)
+        {
+            GTEST_SKIP() << "Skipping MaxPool2D_F8 tests because CK_ENABLE_FP8 is "
+                            "not enabled";
+        }
+    }
+};
+
+TYPED_TEST_SUITE(MaxPool2D_F32, MaxPool2D_F32_Types);
+TYPED_TEST_SUITE(MaxPool2D_F16, MaxPool2D_F16_Types);
+TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
+TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types);
+TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types);
+
+TYPED_TEST(MaxPool2D_F32, MaxPool2D_F32_Test) { this->Run(); }
+TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); }
+TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); }
+TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); }
+TYPED_TEST(MaxPool2D_F8, MaxPool2D_F8_Test) { this->Run(); }