"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "91d13ef4f791920c1ea9e048b92ba8b74833c3ce"
Commit 53dba87a authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add macros to enable build with disabled fp8/bf8

parent 59954f5a
...@@ -89,6 +89,7 @@ struct PassThrough ...@@ -89,6 +89,7 @@ struct PassThrough
} }
#endif #endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <> template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{ {
...@@ -118,6 +119,7 @@ struct PassThrough ...@@ -118,6 +119,7 @@ struct PassThrough
{ {
y = type_convert<f8_t>(x); y = type_convert<f8_t>(x);
} }
#endif
}; };
struct UnaryConvert struct UnaryConvert
...@@ -146,6 +148,7 @@ struct ConvertBF16RTN ...@@ -146,6 +148,7 @@ struct ConvertBF16RTN
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
struct ConvertF8SR struct ConvertF8SR
{ {
// convert to fp8 using stochastic rounding (SR) // convert to fp8 using stochastic rounding (SR)
...@@ -162,6 +165,7 @@ struct ConvertF8SR ...@@ -162,6 +165,7 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x); y = f8_convert_sr<Y>(x);
} }
}; };
#endif
struct Scale struct Scale
{ {
......
...@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64> ...@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{ {
...@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8> ...@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops> template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector struct MfmaSelector
...@@ -640,6 +642,7 @@ struct MfmaSelector ...@@ -640,6 +642,7 @@ struct MfmaSelector
} }
#endif #endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() static constexpr auto GetMfma<f8_t, 32, 32>()
{ {
...@@ -651,6 +654,7 @@ struct MfmaSelector ...@@ -651,6 +654,7 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{}; static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
...@@ -852,7 +856,11 @@ struct XdlopsGemm ...@@ -852,7 +856,11 @@ struct XdlopsGemm
{ {
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value || static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value || is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value, is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|| is_same<base_type, f8_t>::value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!"); "base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
} }
else else
{ {
#endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>( return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif
#else #else
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
} }
else else
{ {
#endif
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif #endif
#endif
} }
// buffer_load requires: // buffer_load requires:
...@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = auto tmp =
...@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
} }
else else
{ {
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif
#else #else
if(dst_thread_element_valid) if(dst_thread_element_valid)
{ {
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>( auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
...@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
} }
else else
{ {
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif
} }
#endif #endif
} }
......
...@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8; struct intrin_mfma_f32_32x32x16f8f8;
...@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif #endif
} }
}; };
#endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -12,8 +12,10 @@ using half_t = _Float16; ...@@ -12,8 +12,10 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
#endif #endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
using f8_t = _BitInt(8); using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8); using bf8_t = unsigned _BitInt(8);
#endif
template <typename T> template <typename T>
inline __host__ __device__ constexpr auto is_native() inline __host__ __device__ constexpr auto is_native()
...@@ -152,6 +154,7 @@ struct scalar_type<int4_t> ...@@ -152,6 +154,7 @@ struct scalar_type<int4_t>
}; };
#endif #endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <> template <>
struct scalar_type<f8_t> struct scalar_type<f8_t>
{ {
...@@ -165,6 +168,7 @@ struct scalar_type<bf8_t> ...@@ -165,6 +168,7 @@ struct scalar_type<bf8_t>
using type = bf8_t; using type = bf8_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#endif
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
...@@ -967,6 +971,7 @@ using int8x16_t = typename vector_type<int8_t, 16>::type; ...@@ -967,6 +971,7 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type; using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// f8 // f8
using f8x2_t = typename vector_type<f8_t, 2>::type; using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type; using f8x4_t = typename vector_type<f8_t, 4>::type;
...@@ -982,6 +987,7 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type; ...@@ -982,6 +987,7 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type; using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type; using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type; using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
...@@ -1029,6 +1035,7 @@ struct NumericLimits<int4_t> ...@@ -1029,6 +1035,7 @@ struct NumericLimits<int4_t>
}; };
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <> template <>
struct NumericLimits<f8_t> struct NumericLimits<f8_t>
{ {
...@@ -1074,5 +1081,6 @@ struct NumericLimits<bf8_t> ...@@ -1074,5 +1081,6 @@ struct NumericLimits<bf8_t>
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); } __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
}; };
#endif
} // namespace ck } // namespace ck
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck { namespace ck {
// fp8 rounding modes // fp8 rounding modes
...@@ -283,3 +284,4 @@ __host__ __device__ Y cast_from_f8(X x) ...@@ -283,3 +284,4 @@ __host__ __device__ Y cast_from_f8(X x)
} }
} // namespace ck::utils } // namespace ck::utils
#endif
...@@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// convert fp32 to fp8 // convert fp32 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
...@@ -163,6 +164,7 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x) ...@@ -163,6 +164,7 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x); return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
} }
#endif
// Declare a template function for bf16 conversion using RTN // Declare a template function for bf16 conversion using RTN
template <typename Y, typename X> template <typename Y, typename X>
...@@ -221,6 +223,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h ...@@ -221,6 +223,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
return bf16_convert_rtn<bhalf_t>(x_fp32); return bf16_convert_rtn<bhalf_t>(x_fp32);
} }
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// Declare a template function for fp8 conversion using SR // Declare a template function for fp8 conversion using SR
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x); __host__ __device__ constexpr Y f8_convert_sr(X x);
...@@ -284,5 +287,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x) ...@@ -284,5 +287,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng);
} }
#endif
} // namespace ck } // namespace ck
...@@ -17,10 +17,12 @@ namespace instance { ...@@ -17,10 +17,12 @@ namespace instance {
using F64 = double; using F64 = double;
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using F8 = ck::f8_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
using F8 = ck::f8_t;
#endif
using Empty_Tuple = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
......
...@@ -230,6 +230,7 @@ check_err(const Range& out, ...@@ -230,6 +230,7 @@ check_err(const Range& out,
return res; return res;
} }
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <typename Range, typename RefRange> template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
(std::is_same_v<ranges::range_value_t<Range>, f8_t> || (std::is_same_v<ranges::range_value_t<Range>, f8_t> ||
...@@ -276,6 +277,7 @@ check_err(const Range& out, ...@@ -276,6 +277,7 @@ check_err(const Range& out,
} }
return res; return res;
} }
#endif
} // namespace utils } // namespace utils
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment