Unverified Commit 689a5ae4 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Pass build flags to config.h (#1760)

* pass the build flags to config.h

* fix clang format
parent 6ef8d3c2
...@@ -183,14 +183,17 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") ...@@ -183,14 +183,17 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
message("Enabling XDL instances") message("Enabling XDL instances")
add_definitions(-DCK_USE_XDL) add_definitions(-DCK_USE_XDL)
set(CK_USE_XDL "ON")
endif() endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") if (SUPPORTED_GPU_TARGETS MATCHES "gfx94")
message("Enabling FP8 gemms on native architectures") message("Enabling FP8 gemms on native architectures")
add_definitions(-DCK_USE_GFX94) add_definitions(-DCK_USE_GFX94)
set(CK_USE_GFX94 "ON")
endif() endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message("Enabling WMMA instances") message("Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA) add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
endif() endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
add_definitions(-DCK_USE_OCP_FP8) add_definitions(-DCK_USE_OCP_FP8)
...@@ -204,6 +207,7 @@ endif() ...@@ -204,6 +207,7 @@ endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
set(CK_USE_FP8_ON_UNSUPPORTED_ARCH "ON")
endif() endif()
# CK config file to record supported datatypes, etc. # CK config file to record supported datatypes, etc.
......
...@@ -111,6 +111,22 @@ ...@@ -111,6 +111,22 @@
#cmakedefine CK_USE_WMMA @CK_USE_WMMA@ #cmakedefine CK_USE_WMMA @CK_USE_WMMA@
#endif #endif
#ifndef CK_USE_GFX94
#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
#endif
#ifndef DCK_USE_OCP_FP8
#cmakedefine DCK_USE_OCP_FP8 @DCK_USE_OCP_FP8@
#endif
#ifndef CK_USE_FNUZ_FP8
#cmakedefine CK_USE_FNUZ_FP8 @CK_USE_FNUZ_FP8@
#endif
#ifndef CK_USE_FP8_ON_UNSUPPORTED_ARCH
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif
// clang-format on // clang-format on
#endif // CK_CONFIG_H_IN #endif // CK_CONFIG_H_IN
...@@ -20,9 +20,17 @@ ...@@ -20,9 +20,17 @@
namespace { namespace {
// https://en.cppreference.com/w/cpp/types/conditional // https://en.cppreference.com/w/cpp/types/conditional
template <bool B, class T, class F> struct conditional { using type = T; }; template <bool B, class T, class F>
template <class T, class F> struct conditional<false, T, F> { using type = F; }; struct conditional
} {
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
} // namespace
namespace ck { namespace ck {
...@@ -200,8 +208,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) ...@@ -200,8 +208,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
typename conditional< typename conditional<
sizeof(T) == 2, sizeof(T) == 2,
unsigned short int, unsigned short int,
typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>:: typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type retval;
type>::type retval;
if constexpr(we == 5 && is_half && !is_fnuz) if constexpr(we == 5 && is_half && !is_fnuz)
{ {
...@@ -547,8 +554,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -547,8 +554,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
using T_bitwise = typename conditional< using T_bitwise = typename conditional<
sizeof(T) == 2, sizeof(T) == 2,
unsigned short int, unsigned short int,
typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>:: typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
type>::type;
T_bitwise x_bitwise = bit_cast<T_bitwise>(_x); T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
unsigned long long x{x_bitwise}; unsigned long long x{x_bitwise};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment