"examples/unconditional_image_generation/README.md" did not exist on "85244d4a5901cc5062562b15361b06fdbeebf528"
Commit 536c6651 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Re-enable fp8 gemms for gfx94/95

parent a6422f3f
......@@ -185,6 +185,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
add_definitions(-DCK_USE_XDL)
set(CK_USE_XDL "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95")
message("Enabling FP8 gemms on native architectures")
add_definitions(-DCK_USE_GFX94)
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message("Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA)
......
......@@ -101,7 +101,7 @@ int profile_gemm_universal(int argc, char* argv[])
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t;
#endif
......@@ -164,7 +164,7 @@ int profile_gemm_universal(int argc, char* argv[])
{
return profile(F16{}, F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
......@@ -198,7 +198,7 @@ int profile_gemm_universal(int argc, char* argv[])
{
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
}
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{});
......
......@@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>,
#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>,
......@@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types<
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>,
#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment