Commit 53dba87a authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add macros to enable build with disabled fp8/bf8

parent 59954f5a
......@@ -89,6 +89,7 @@ struct PassThrough
}
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
......@@ -118,6 +119,7 @@ struct PassThrough
{
y = type_convert<f8_t>(x);
}
#endif
};
struct UnaryConvert
......@@ -146,6 +148,7 @@ struct ConvertBF16RTN
}
};
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
......@@ -162,6 +165,7 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x);
}
};
#endif
struct Scale
{
......
......@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{
......@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector
......@@ -640,6 +642,7 @@ struct MfmaSelector
}
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
static constexpr auto GetMfma<f8_t, 32, 32>()
{
......@@ -651,6 +654,7 @@ struct MfmaSelector
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
......@@ -852,7 +856,11 @@ struct XdlopsGemm
{
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|| is_same<base_type, f8_t>::value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
......@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
......@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
}
else
{
#endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#else
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
......@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
}
else
{
#endif
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#endif
}
// buffer_load requires:
......@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp =
......@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
else
{
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#else
if(dst_thread_element_valid)
{
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
......@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
else
{
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
}
#endif
}
......
......@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
......@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
}
};
#endif
} // namespace ck
#endif
......@@ -12,8 +12,10 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
#endif
template <typename T>
inline __host__ __device__ constexpr auto is_native()
......@@ -152,6 +154,7 @@ struct scalar_type<int4_t>
};
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct scalar_type<f8_t>
{
......@@ -165,6 +168,7 @@ struct scalar_type<bf8_t>
using type = bf8_t;
static constexpr index_t vector_size = 1;
};
#endif
template <typename T>
struct vector_type<T, 1>
......@@ -967,6 +971,7 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// f8
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
......@@ -982,6 +987,7 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T>
struct NumericLimits
......@@ -1029,6 +1035,7 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct NumericLimits<f8_t>
{
......@@ -1074,5 +1081,6 @@ struct NumericLimits<bf8_t>
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
#endif
} // namespace ck
......@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp"
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck {
// fp8 rounding modes
......@@ -283,3 +284,4 @@ __host__ __device__ Y cast_from_f8(X x)
}
} // namespace ck::utils
#endif
......@@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
......@@ -163,6 +164,7 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
}
#endif
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
......@@ -221,6 +223,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
......@@ -284,5 +287,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
#endif
} // namespace ck
......@@ -17,10 +17,12 @@ namespace instance {
using F64 = double;
using F32 = float;
using F16 = ck::half_t;
using F8 = ck::f8_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using I32 = int32_t;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
using F8 = ck::f8_t;
#endif
using Empty_Tuple = ck::Tuple<>;
......
......@@ -230,6 +230,7 @@ check_err(const Range& out,
return res;
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
(std::is_same_v<ranges::range_value_t<Range>, f8_t> ||
......@@ -276,6 +277,7 @@ check_err(const Range& out,
}
return res;
}
#endif
} // namespace utils
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment