Commit c5ad2e80 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 4b798833 489c78d0
...@@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK ...@@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK
using KernelTypes_MK_KN = ::testing::Types< using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType // ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>, std::tuple< F16, F16, F16, F16>,
#if (defined CK_ENABLE_FP8) #if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
std::tuple< F16, F8, F16, F16>, std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>, std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>, std::tuple< F8, F8, F8, BF16>,
...@@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types< ...@@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types<
using KernelTypes_MK_NK = ::testing::Types< using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType // ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>, std::tuple< F16, F16, F16, F16>,
#if (defined CK_ENABLE_FP8) #if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
std::tuple< F16, F8, F16, F16>, std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>, std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>, std::tuple< F8, F8, F8, BF16>,
......
...@@ -58,13 +58,13 @@ using KernelTypes1d = ::testing::Types<std::tuple<float, GNWC, GKXC, GNWK>, ...@@ -58,13 +58,13 @@ using KernelTypes1d = ::testing::Types<std::tuple<float, GNWC, GKXC, GNWK>,
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>, using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
std::tuple<ck::half_t, GNHWC, GKYXC, GNHWK>, std::tuple<ck::half_t, GNHWC, GKYXC, GNHWK>,
std::tuple<ck::bhalf_t, GNHWC, GKYXC, GNHWK>, std::tuple<ck::bhalf_t, GNHWC, GKYXC, GNHWK>,
std::tuple<int8_t, GNHWC, GKYXC, GNHWK>,
std::tuple<float, NHWGC, GKYXC, NHWGK>, std::tuple<float, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>, std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>, std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
std::tuple<int8_t, NHWGC, GKYXC, NHWGK>, std::tuple<int8_t, NHWGC, GKYXC, NHWGK>,
std::tuple<float, NGCHW, GKYXC, NGKHW>, std::tuple<float, NGCHW, GKYXC, NGKHW>,
std::tuple<ck::half_t, NGCHW, GKYXC, NGKHW>>; std::tuple<ck::half_t, NGCHW, GKYXC, NGKHW>,
std::tuple<int8_t, NGCHW, GKYXC, NGKHW>>;
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>, using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>, std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
......
...@@ -52,7 +52,8 @@ using namespace ck::tensor_layout::convolution; ...@@ -52,7 +52,8 @@ using namespace ck::tensor_layout::convolution;
using KernelTypes2d = ::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK>, using KernelTypes2d = ::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>, std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>>; std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
std::tuple<int8_t, NHWGC, GKYXC, NHWGK>>;
using KernelTypes3d = ::testing::Types<std::tuple<float, NDHWGC, GKZYXC, NDHWGK>, using KernelTypes3d = ::testing::Types<std::tuple<float, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK>, std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment