Unverified commit b4a79045, authored by Illia Silin and committed by GitHub
Browse files

re-enable fp8 gemms in ckProfiler (#1667)

parent 3b6a481e
...@@ -183,12 +183,14 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") ...@@ -183,12 +183,14 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
message("Enabling XDL instances") message("Enabling XDL instances")
add_definitions(-DCK_USE_XDL) add_definitions(-DCK_USE_XDL)
set(CK_USE_XDL "ON") endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94")
message("Enabling FP8 gemms in ckProfiler")
add_definitions(-DCK_USE_GFX94)
endif() endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message("Enabling WMMA instances") message("Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA) add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
endif() endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
......
...@@ -101,7 +101,7 @@ int profile_gemm_universal(int argc, char* argv[]) ...@@ -101,7 +101,7 @@ int profile_gemm_universal(int argc, char* argv[])
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t; using F8 = ck::f8_t;
#endif #endif
...@@ -164,7 +164,7 @@ int profile_gemm_universal(int argc, char* argv[]) ...@@ -164,7 +164,7 @@ int profile_gemm_universal(int argc, char* argv[])
{ {
return profile(F16{}, F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); return profile(F16{}, F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
} }
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F16{}, F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); return profile(F16{}, F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
...@@ -198,7 +198,7 @@ int profile_gemm_universal(int argc, char* argv[]) ...@@ -198,7 +198,7 @@ int profile_gemm_universal(int argc, char* argv[])
{ {
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{}); return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
} }
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{}); return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{});
......
...@@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK ...@@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK
using KernelTypes_MK_KN = ::testing::Types< using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType // ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>, std::tuple< F16, F16, F16, F16>,
#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>, std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>, std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>, std::tuple< F8, F8, F8, BF16>,
...@@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types< ...@@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types<
using KernelTypes_MK_NK = ::testing::Types< using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType // ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>, std::tuple< F16, F16, F16, F16>,
#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>, std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>, std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>, std::tuple< F8, F8, F8, BF16>,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment