Commit 5929e503 authored by Haocong WANG's avatar Haocong WANG
Browse files

Merge branch 'jizhan/enable_bf16_atomic_add' of...

Merge branch 'jizhan/enable_bf16_atomic_add' of https://github.com/zjing14/composable_kernel into gemm_opt_rocm6.2
parents c2e13d9b a6390bbe
......@@ -17,6 +17,7 @@ class TestGroupedConvndFwd : public ::testing::Test
using InLayout = std::tuple_element_t<1, Tuple>;
using WeiLayout = std::tuple_element_t<2, Tuple>;
using OutLayout = std::tuple_element_t<3, Tuple>;
using IndexType = std::tuple_element_t<4, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params;
......@@ -33,7 +34,10 @@ class TestGroupedConvndFwd : public ::testing::Test
OutLayout,
DataType,
DataType,
DataType>(
DataType,
DataType,
DataType,
IndexType>(
true, // do_verification
1, // init_method: integer value
false, // do_log
......@@ -46,30 +50,31 @@ class TestGroupedConvndFwd : public ::testing::Test
using namespace ck::tensor_layout::convolution;
using KernelTypes1d = ::testing::Types<std::tuple<float, GNWC, GKXC, GNWK>,
std::tuple<ck::half_t, GNWC, GKXC, GNWK>,
std::tuple<ck::bhalf_t, GNWC, GKXC, GNWK>,
std::tuple<int8_t, GNWC, GKXC, GNWK>>;
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
std::tuple<ck::half_t, GNHWC, GKYXC, GNHWK>,
std::tuple<ck::bhalf_t, GNHWC, GKYXC, GNHWK>,
std::tuple<int8_t, GNHWC, GKYXC, GNHWK>,
std::tuple<float, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
std::tuple<int8_t, NHWGC, GKYXC, NHWGK>>;
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<int8_t, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<float, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK>>;
using KernelTypes2dLargeCases = ::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK>>;
using KernelTypes1d = ::testing::Types<std::tuple<float, GNWC, GKXC, GNWK, ck::index_t>,
std::tuple<ck::half_t, GNWC, GKXC, GNWK, ck::index_t>,
std::tuple<ck::bhalf_t, GNWC, GKXC, GNWK, ck::index_t>,
std::tuple<int8_t, GNWC, GKXC, GNWK, ck::index_t>>;
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK, ck::index_t>,
std::tuple<ck::half_t, GNHWC, GKYXC, GNHWK, ck::index_t>,
std::tuple<ck::bhalf_t, GNHWC, GKYXC, GNHWK, ck::index_t>,
std::tuple<int8_t, GNHWC, GKYXC, GNHWK, ck::index_t>,
std::tuple<float, NHWGC, GKYXC, NHWGK, ck::index_t>,
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK, ck::index_t>,
std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::index_t>,
std::tuple<int8_t, NHWGC, GKYXC, NHWGK, ck::index_t>>;
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
std::tuple<ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
std::tuple<int8_t, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
std::tuple<float, NDHWGC, GKZYXC, NDHWGK, ck::index_t>,
std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::index_t>,
std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::index_t>,
std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK, ck::index_t>>;
using KernelTypes2dLargeCases =
::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK, ck::long_index_t>>;
template <typename Tuple>
class TestGroupedConvndFwd1d : public TestGroupedConvndFwd<Tuple>
......@@ -153,5 +158,8 @@ TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
// With supported NumGroupsToMerge > 1
this->conv_params.push_back(
{2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}});
// When image is larger than 2GB
this->conv_params.push_back(
{2, 1, 1, 256, 256, {3, 3}, {4096, 2048}, {1024, 1024}, {3, 3}, {1, 1}, {1, 1}});
this->template Run<2>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment