Commit 2a51cf49 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Decouple fp8 and bf8 flags

parent c028e416
...@@ -89,7 +89,7 @@ struct PassThrough ...@@ -89,7 +89,7 @@ struct PassThrough
} }
#endif #endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
template <> template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{ {
...@@ -148,7 +148,7 @@ struct ConvertBF16RTN ...@@ -148,7 +148,7 @@ struct ConvertBF16RTN
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
struct ConvertF8SR struct ConvertF8SR
{ {
// convert to fp8 using stochastic rounding (SR) // convert to fp8 using stochastic rounding (SR)
......
...@@ -456,7 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64> ...@@ -456,7 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{ {
...@@ -642,7 +642,7 @@ struct MfmaSelector ...@@ -642,7 +642,7 @@ struct MfmaSelector
} }
#endif #endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() static constexpr auto GetMfma<f8_t, 32, 32>()
{ {
...@@ -857,7 +857,7 @@ struct XdlopsGemm ...@@ -857,7 +857,7 @@ struct XdlopsGemm
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value || static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value || is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
|| is_same<base_type, f8_t>::value || is_same<base_type, f8_t>::value
#endif #endif
, ,
......
...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1139,11 +1139,11 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1139,11 +1139,11 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#endif #endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>( return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
} }
#endif #endif
#else #else
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1156,7 +1156,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1156,7 +1156,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
} }
#endif #endif
#endif #endif
...@@ -1216,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1216,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = auto tmp =
...@@ -1229,13 +1229,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1229,13 +1229,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#endif #endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
} }
#endif #endif
#else #else
if(dst_thread_element_valid) if(dst_thread_element_valid)
{ {
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>( auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
...@@ -1248,7 +1248,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1248,7 +1248,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#endif #endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
} }
#endif #endif
} }
......
...@@ -355,7 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -355,7 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8; struct intrin_mfma_f32_32x32x16f8f8;
......
...@@ -12,8 +12,10 @@ using half_t = _Float16; ...@@ -12,8 +12,10 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
#endif #endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
using f8_t = _BitInt(8); using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8); using bf8_t = unsigned _BitInt(8);
#endif #endif
...@@ -146,14 +148,16 @@ struct scalar_type<int4_t> ...@@ -146,14 +148,16 @@ struct scalar_type<int4_t>
}; };
#endif #endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
template <> template <>
struct scalar_type<f8_t> struct scalar_type<f8_t>
{ {
using type = f8_t; using type = f8_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#endif
#if defined CK_ENABLE_BF8
template <> template <>
struct scalar_type<bf8_t> struct scalar_type<bf8_t>
{ {
...@@ -963,16 +967,18 @@ using int8x16_t = typename vector_type<int8_t, 16>::type; ...@@ -963,16 +967,18 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type; using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// f8 // f8
#if defined CK_ENABLE_FP8
using f8x2_t = typename vector_type<f8_t, 2>::type; using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type; using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type; using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type; using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type; using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type; using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8 // bf8
#if defined CK_ENABLE_BF8
using bf8x2_t = typename vector_type<bf8_t, 2>::type; using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type; using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type; using bf8x8_t = typename vector_type<bf8_t, 8>::type;
...@@ -1027,7 +1033,7 @@ struct NumericLimits<int4_t> ...@@ -1027,7 +1033,7 @@ struct NumericLimits<int4_t>
}; };
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
template <> template <>
struct NumericLimits<f8_t> struct NumericLimits<f8_t>
{ {
...@@ -1050,7 +1056,9 @@ struct NumericLimits<f8_t> ...@@ -1050,7 +1056,9 @@ struct NumericLimits<f8_t>
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); } __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
}; };
#endif
#if defined CK_ENABLE_BF8
template <> template <>
struct NumericLimits<bf8_t> struct NumericLimits<bf8_t>
{ {
......
...@@ -80,7 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -80,7 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
// convert fp32 to fp8 // convert fp32 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
...@@ -122,7 +122,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) ...@@ -122,7 +122,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x); return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
} }
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 // convert fp32 to bf8
template <> template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x) inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
...@@ -223,11 +225,11 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h ...@@ -223,11 +225,11 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
return bf16_convert_rtn<bhalf_t>(x_fp32); return bf16_convert_rtn<bhalf_t>(x_fp32);
} }
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// Declare a template function for fp8 conversion using SR // Declare a template function for fp8 conversion using SR
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x); __host__ __device__ constexpr Y f8_convert_sr(X x);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding // convert fp32 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
...@@ -257,7 +259,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -257,7 +259,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng);
} }
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 with stochastic rounding // convert fp32 to bf8 with stochastic rounding
template <> template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
......
...@@ -20,9 +20,12 @@ using F16 = ck::half_t; ...@@ -20,9 +20,12 @@ using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
using F8 = ck::f8_t; using F8 = ck::f8_t;
#endif #endif
#if defined CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using Empty_Tuple = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
......
...@@ -45,7 +45,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_ ...@@ -45,7 +45,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_
PassThrough, PassThrough,
MultiplyAdd>>>&); MultiplyAdd>>>&);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances( void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row, Row,
...@@ -133,7 +133,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -133,7 +133,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} }
} }
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<D0DataType, float> && is_same_v<D1DataType, float> && is_same_v<D0DataType, float> && is_same_v<D1DataType, float> &&
is_same_v<EDataType, half_t>) is_same_v<EDataType, half_t>)
......
...@@ -57,7 +57,7 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( ...@@ -57,7 +57,7 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances( void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmSplitK<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -178,7 +178,7 @@ struct DeviceOperationInstanceFactory< ...@@ -178,7 +178,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
} }
} }
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> && else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>) is_same_v<CDataType, half_t>)
{ {
......
...@@ -230,11 +230,57 @@ check_err(const Range& out, ...@@ -230,11 +230,57 @@ check_err(const Range& out,
return res; return res;
} }
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
template <typename Range, typename RefRange> template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
(std::is_same_v<ranges::range_value_t<Range>, f8_t> || std::is_same_v<ranges::range_value_t<Range>, f8_t>),
std::is_same_v<ranges::range_value_t<Range>, bf8_t>)), bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
}
#endif
#if defined CK_ENABLE_BF8
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
bool> bool>
check_err(const Range& out, check_err(const Range& out,
const RefRange& ref, const RefRange& ref,
...@@ -252,7 +298,6 @@ check_err(const Range& out, ...@@ -252,7 +298,6 @@ check_err(const Range& out,
bool res{true}; bool res{true};
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
// TODO: This is a hack. We should have proper specialization for bhalf_t data type.
double max_err = std::numeric_limits<float>::min(); double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
......
...@@ -214,7 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -214,7 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification,
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl; << kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
// set softer tolerances for fp8 // set softer tolerances for fp8
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> || if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<CDataType, f8_t>) is_same_v<CDataType, f8_t>)
...@@ -229,7 +229,7 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -229,7 +229,7 @@ bool profile_gemm_splitk_impl(int do_verification,
{ {
#endif #endif
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
} }
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment