Unverified Commit 1504c3e8 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #219 from ROCm/andriy/lwpck-2430

Add support of OCP FP8 data types in CK for gfx950 arch
parents b6f7cddd 27a05c7e
......@@ -3,6 +3,7 @@
#pragma once
#include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/statically_indexed_array.hpp"
namespace ck {
......@@ -10,8 +11,6 @@ namespace ck {
using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
inline constexpr auto next_pow2(uint32_t x)
{
......@@ -19,14 +18,15 @@ inline constexpr auto next_pow2(uint32_t x)
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
template <typename T>
inline constexpr bool is_native_type()
{
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
is_same<T, uint8_t>::value || is_same<T, f8_t>::value || is_same<T, bf8_t>::value ||
is_same<T, bool>::value;
is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
}
// vector_type
......@@ -166,16 +166,30 @@ struct scalar_type<int4_t>
#endif
template <>
struct scalar_type<f8_t>
struct scalar_type<f8_fnuz_t>
{
using type = f8_t;
using type = f8_fnuz_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_t>
struct scalar_type<bf8_fnuz_t>
{
using type = bf8_t;
using type = bf8_fnuz_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<f8_ocp_t>
{
using type = f8_ocp_t::data_type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_ocp_t>
{
using type = bf8_ocp_t::data_type;
static constexpr index_t vector_size = 1;
};
......@@ -1023,46 +1037,82 @@ struct non_native_vector_base
T d[N];
};
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>;
template <index_t N>
struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
{
using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
static constexpr index_t vector_size = N;
};
template <index_t N>
struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
{
using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
static constexpr index_t vector_size = N;
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using type = d1_t;
using d1_nnv_t = non_native_vector_base<T, 1>;
using type = d1_nnv_t;
union alignas(next_pow2(1 * sizeof(T)))
{
d1_t d1_;
StaticallyIndexedArray<d1_t, 1> d1x1_;
d1_nnv_t d1_nnv_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type() : data_{d1_t{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ constexpr vector_type(type v) : data_{{v}} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using type = d2_t;
......@@ -1081,10 +1131,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x2_;
}
......@@ -1101,10 +1152,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x2_;
}
......@@ -1123,6 +1175,7 @@ template <typename T>
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
......@@ -1143,10 +1196,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x4_;
}
......@@ -1167,10 +1221,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x4_;
}
......@@ -1193,6 +1248,7 @@ template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
......@@ -1215,11 +1271,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x8_;
}
......@@ -1244,11 +1301,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x8_;
}
......@@ -1275,6 +1333,7 @@ template <typename T>
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
......@@ -1299,12 +1358,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x16_;
}
......@@ -1333,12 +1392,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x16_;
}
......@@ -1632,20 +1691,70 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
using f8x2_fnuz_t = typename vector_type<f8_fnuz_t, 2>::type;
using f8x4_fnuz_t = typename vector_type<f8_fnuz_t, 4>::type;
using f8x8_fnuz_t = typename vector_type<f8_fnuz_t, 8>::type;
using f8x16_fnuz_t = typename vector_type<f8_fnuz_t, 16>::type;
using f8x32_fnuz_t = typename vector_type<f8_fnuz_t, 32>::type;
using f8x64_fnuz_t = typename vector_type<f8_fnuz_t, 64>::type;
// bf8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
using bf8x2_fnuz_t = typename vector_type<bf8_fnuz_t, 2>::type;
using bf8x4_fnuz_t = typename vector_type<bf8_fnuz_t, 4>::type;
using bf8x8_fnuz_t = typename vector_type<bf8_fnuz_t, 8>::type;
using bf8x16_fnuz_t = typename vector_type<bf8_fnuz_t, 16>::type;
using bf8x32_fnuz_t = typename vector_type<bf8_fnuz_t, 32>::type;
using bf8x64_fnuz_t = typename vector_type<bf8_fnuz_t, 64>::type;
// f8
using f8x2_ocp_t = typename vector_type<f8_ocp_t, 2>::type;
using f8x4_ocp_t = typename vector_type<f8_ocp_t, 4>::type;
using f8x8_ocp_t = typename vector_type<f8_ocp_t, 8>::type;
using f8x16_ocp_t = typename vector_type<f8_ocp_t, 16>::type;
using f8x32_ocp_t = typename vector_type<f8_ocp_t, 32>::type;
using f8x64_ocp_t = typename vector_type<f8_ocp_t, 64>::type;
// bf8
using bf8x2_ocp_t = typename vector_type<bf8_ocp_t, 2>::type;
using bf8x4_ocp_t = typename vector_type<bf8_ocp_t, 4>::type;
using bf8x8_ocp_t = typename vector_type<bf8_ocp_t, 8>::type;
using bf8x16_ocp_t = typename vector_type<bf8_ocp_t, 16>::type;
using bf8x32_ocp_t = typename vector_type<bf8_ocp_t, 32>::type;
using bf8x64_ocp_t = typename vector_type<bf8_ocp_t, 64>::type;
#if CK_FP8_TYPE_OCP
// f8
using f8x2_t = f8x2_ocp_t;
using f8x4_t = f8x4_ocp_t;
using f8x8_t = f8x8_ocp_t;
using f8x16_t = f8x16_ocp_t;
using f8x32_t = f8x32_ocp_t;
using f8x64_t = f8x64_ocp_t;
// bf8
using bf8x2_t = bf8x2_ocp_t;
using bf8x4_t = bf8x4_ocp_t;
using bf8x8_t = bf8x8_ocp_t;
using bf8x16_t = bf8x16_ocp_t;
using bf8x32_t = bf8x32_ocp_t;
using bf8x64_t = bf8x64_ocp_t;
#elif CK_FP8_TYPE_FNUZ
// f8
using f8x2_t = f8x2_fnuz_t;
using f8x4_t = f8x4_fnuz_t;
using f8x8_t = f8x8_fnuz_t;
using f8x16_t = f8x16_fnuz_t;
using f8x32_t = f8x32_fnuz_t;
using f8x64_t = f8x64_fnuz_t;
// bf8
using bf8x2_t = bf8x2_fnuz_t;
using bf8x4_t = bf8x4_fnuz_t;
using bf8x8_t = bf8x8_fnuz_t;
using bf8x16_t = bf8x16_fnuz_t;
using bf8x32_t = bf8x32_fnuz_t;
using bf8x64_t = bf8x64_fnuz_t;
#endif
// u8
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
......@@ -1702,7 +1811,7 @@ struct NumericLimits<int4_t>
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<f8_t>
struct NumericLimits<f8_fnuz_t>
{
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000
......@@ -1715,17 +1824,17 @@ struct NumericLimits<f8_t>
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
__host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
__host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
__host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
__host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
};
template <>
struct NumericLimits<bf8_t>
struct NumericLimits<bf8_fnuz_t>
{
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
......@@ -1738,13 +1847,59 @@ struct NumericLimits<bf8_t>
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
__host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
__host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
__host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
__host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
};
template <>
struct NumericLimits<f8_ocp_t>
{
static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6
static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448
static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111
__host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
__host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
__host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
__host__ __device__ static constexpr f8_ocp_t Lowest()
{
return bit_cast<f8_ocp_t>(binary_lowest);
}
__host__ __device__ static constexpr f8_ocp_t QuietNaN()
{
return bit_cast<f8_ocp_t>(binary_qnan);
}
};
template <>
struct NumericLimits<bf8_ocp_t>
{
static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14
static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344
static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101
__host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
__host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
__host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
__host__ __device__ static constexpr bf8_ocp_t Lowest()
{
return bit_cast<bf8_ocp_t>(binary_lowest);
}
__host__ __device__ static constexpr bf8_ocp_t QuietNaN()
{
return bit_cast<bf8_ocp_t>(binary_qnan);
}
};
template <typename T>
......@@ -1787,7 +1942,7 @@ struct NumericUtils<half_t>
};
template <>
struct NumericUtils<f8_t>
struct NumericUtils<f8_fnuz_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
......@@ -1796,13 +1951,28 @@ struct NumericUtils<f8_t>
};
template <>
struct NumericUtils<bf8_t>
struct NumericUtils<bf8_fnuz_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode
};
template <>
struct NumericUtils<f8_ocp_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
static constexpr int bias = 7;
};
template <>
struct NumericUtils<bf8_ocp_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
static constexpr int bias = 15;
};
template <>
struct NumericUtils<bhalf_t>
......
......@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00;
};
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x)
......@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00;
};
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
static inline __device__ half_t sqrt(half_t x)
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
// Pseudo random number generator
......@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// version for fp16
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
......@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// return 0 if data is not fp16 or fp32
template <typename T,
template <
typename T,
uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
std::ignore = id;
......
......@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
template <>
inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
{
return f8_ocp_t{type_convert<f8_ocp_t::data_type>(x)};
}
template <>
inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
{
return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
}
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
......@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
......@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
f8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
......@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<float,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_sr<bf8_t>(type_convert<float>(x));
return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
......@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
{
#if defined(__gfx94__)
union
......@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_rne<f8_t>(type_convert<float>(x));
return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
f8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
{
#if defined(__gfx94__)
union
......@@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<float,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp16 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native converion
return f8_convert_rne<bf8_t>(type_convert<float>(x));
return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_fnuz_t>(x);
#else
return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
return f8_convert_sr<f8_ocp_t>(x);
#else
return f8_convert_rne<f8_t>(x);
return f8_convert_rne<f8_ocp_t>(x);
#endif
}
// convert fp8 to fp32
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
float fval;
......@@ -392,26 +427,26 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
return utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(x);
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnuz_t x)
{
#if defined(__gfx94__)
const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x);
const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif
}
......@@ -428,42 +463,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_fnuz_t>(x);
#else
return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
return f8_convert_sr<f8_ocp_t>(x);
#else
return f8_convert_rne<f8_t>(x);
return f8_convert_rne<f8_ocp_t>(x);
#endif
}
// convert fp8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
return f8_convert_sr<bf8_fnuz_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_ocp_t>(x);
#else
return f8_convert_rne<bf8_ocp_t>(x);
#endif
}
// convert bf8 to fp32
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
float fval;
......@@ -473,31 +530,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
#endif
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_fnuz_t>(x);
#else
return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
return f8_convert_sr<bf8_ocp_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
return f8_convert_rne<bf8_ocp_t>(x);
#endif
}
// convert bf8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -62,9 +62,9 @@ struct ReferenceGemm : public device::BaseOperator
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
AccDataType v_acc = 0;
ComputeTypeA v_a = 0;
ComputeTypeB v_b = 0;
AccDataType v_acc{0};
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};
for(int k = 0; k < K; ++k)
{
......@@ -93,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
CDataType v_c = 0;
CDataType v_c{0};
arg.c_element_op_(v_c, v_acc);
......
......@@ -326,7 +326,7 @@ struct Tensor
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
void SetZero() { ck::ranges::fill<T>(mData, 0); }
void SetZero() { ck::ranges::fill<T>(mData, T{0}); }
template <typename F>
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
float value = 1.0;
template <typename... Is>
ck::bhalf_t operator()(Is...)
ck::half_t operator()(Is...)
{
return ck::type_convert<ck::half_t>(value);
}
......@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
float value = 1.0;
template <typename... Is>
ck::bhalf_t operator()(Is...)
ck::f8_t operator()(Is...)
{
return ck::type_convert<ck::f8_t>(value);
}
......@@ -256,14 +256,33 @@ struct GeneratorTensor_Checkboard
}
};
template <ck::index_t Dim>
/**
* @brief Is used to generate sequential values based on the specified dimension.
*
* @tparam T The type of the tensor values.
* @tparam Dim The specific dimension used for generation.
*
* GeneratorTensor_Sequential<1>{} will generate the following values for a 3x3 tensor:
*
* 0 1 2
* 0 1 2
* 0 1 2
*
* Essentially, the values generated are logical coordinates of the generated element that
* correspond to dimension Dim. E.g. for 2-dimensional tensor and Dim=1, the values are the column
* indices.
*
*/
template <typename T, ck::index_t Dim>
struct GeneratorTensor_Sequential
{
template <typename... Ts>
float operator()(Ts... Xs) const
T operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
return dims[Dim];
float tmp = dims[Dim];
return ck::type_convert<T>(tmp);
}
};
......
......@@ -70,13 +70,13 @@ function(add_instance_library INSTANCE_NAME)
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
foreach(source IN LISTS ARGN)
if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "gemm_multiply_multiply_xdl_f8")
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply_xdl_f8")
message("removing gemm_multiply_multiply_f8 instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_")
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_")
message("removing gemm_universal_f8 instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
......
......@@ -15,7 +15,7 @@ void add_device_pool3d_fwd_ndhwc_f8_instances(
instances)
{
add_device_operation_instances(
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F8, ReduceOpId, false>{});
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
}
void add_device_pool3d_fwd_ndhwc_index_f8_instances(
......@@ -23,7 +23,7 @@ void add_device_pool3d_fwd_ndhwc_index_f8_instances(
instances)
{
add_device_operation_instances(
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F8, ReduceOpId, true>{});
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F32, ReduceOpId, true>{});
}
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -150,7 +150,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -74,8 +74,8 @@ int profile_gemm_impl(int do_verification,
switch(init_method)
{
case 0:
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
ck::utils::FillConstant<ADataType>{type_convert<ADataType>(1.f)}(a_m_k);
ck::utils::FillConstant<BDataType>{type_convert<BDataType>(1.f)}(b_k_n);
break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
......
......@@ -206,7 +206,7 @@ add_subdirectory(wrapper)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
if((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
add_subdirectory(smfmac_op)
endif()
add_subdirectory(position_embedding)
......
......@@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4)
endif()
endif()
add_gtest_executable(test_fp8 test_fp8.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp8 PRIVATE utility)
add_custom_target(test_fp8)
if (CK_USE_OCP_FP8)
add_gtest_executable(test_fp8_ocp test_fp8_ocp.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp8_ocp PRIVATE utility)
endif()
add_gtest_executable(test_bf8_ocp test_bf8_ocp.cpp)
if(result EQUAL 0)
target_link_libraries(test_bf8_ocp PRIVATE utility)
endif()
add_dependencies(test_fp8 test_fp8_ocp)
add_dependencies(test_fp8 test_bf8_ocp)
endif()
add_gtest_executable(test_bf8 test_bf8.cpp)
if(result EQUAL 0)
target_link_libraries(test_bf8 PRIVATE utility)
if (CK_USE_FNUZ_FP8)
add_gtest_executable(test_fp8_fnuz test_fp8_fnuz.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp8_fnuz PRIVATE utility)
endif()
add_gtest_executable(test_bf8_fnuz test_bf8_fnuz.cpp)
if(result EQUAL 0)
target_link_libraries(test_bf8_fnuz PRIVATE utility)
endif()
add_dependencies(test_fp8 test_fp8_fnuz)
add_dependencies(test_fp8 test_bf8_fnuz)
endif()
add_gtest_executable(test_custom_type test_custom_type.cpp)
......
......@@ -5,158 +5,169 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bf8_t;
using ck::bf8_fnuz_t;
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::half_t;
using ck::type_convert;
TEST(BF8, NumericLimits)
TEST(BF8FNUZ, NumericLimits)
{
// constants given for negative zero nan mode
EXPECT_EQ(ck::NumericLimits<bf8_t>::Min(), type_convert<bf8_t>(0x04));
EXPECT_EQ(ck::NumericLimits<bf8_t>::Max(), type_convert<bf8_t>(0x7F));
EXPECT_EQ(ck::NumericLimits<bf8_t>::Lowest(), type_convert<bf8_t>(0xFF));
EXPECT_EQ(ck::NumericLimits<bf8_t>::QuietNaN(), type_convert<bf8_t>(0x80));
EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Min(), type_convert<bf8_fnuz_t>(0x04));
EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Max(), type_convert<bf8_fnuz_t>(0x7F));
EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Lowest(), type_convert<bf8_fnuz_t>(0xFF));
EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(), type_convert<bf8_fnuz_t>(0x80));
}
TEST(BF8, ConvertFP32Nearest)
TEST(BF8FNUZ, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_t>(0.0f)), abs_tol);
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(0.0f)), abs_tol);
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::min())),
type_convert<float>(f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::min())),
abs_tol);
#endif
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_rne<bf8_t>(57344.0f)), abs_tol);
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_fnuz_t>::Max());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR(
max_bf8_t_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(max_bf8_t_float)), abs_tol);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(57344.0f,
type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::max())),
ASSERT_NEAR(max_bf8_t_float,
type_convert<float>(f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::max())),
abs_tol);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR(type_convert<bf8_t>(0x80),
f8_convert_rne<bf8_t>(std::numeric_limits<float>::infinity()),
// convert inf float to bf8_fnuz_t and check if it is qNan
ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::infinity()),
abs_tol);
// positive norm float value to bf8 and back, check if holds
float pos_float = 0.0000762939f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol);
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(pos_float)), abs_tol);
// negative norm float value to bf8 and back, check if holds
float neg_float = -0.0000610351f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol);
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(neg_float)), abs_tol);
// positive subnorm float value to bf8 and back, check if holds
pos_float = 0.0000305175f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol);
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(pos_float)), abs_tol);
// negative subnorm float value to bf8 and back, check if holds
neg_float = -0.0000152587f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol);
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(neg_float)), abs_tol);
}
TEST(BF8, ConvertFP32Stochastic)
TEST(BF8FNUZ, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_t>(0.0f)), abs_tol);
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(0.0f)), abs_tol);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::min())),
type_convert<float>(f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::min())),
abs_tol);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_sr<bf8_t>(57344.0f)), abs_tol);
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_fnuz_t>::Max());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR(
max_bf8_t_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(max_bf8_t_float)), abs_tol);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(57344.0f,
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::max())),
ASSERT_NEAR(max_bf8_t_float,
type_convert<float>(f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::max())),
abs_tol);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR(type_convert<bf8_t>(0x80),
f8_convert_sr<bf8_t>(std::numeric_limits<float>::infinity()),
// convert inf float to bf8_fnuz_t and check if it is qNan
ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::infinity()),
abs_tol);
// positive norm float value to bf8 and back, check if holds
float pos_float = 0.0000762939f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(pos_float)), abs_tol);
// negative norm float value to bf8 and back, check if holds
float neg_float = -0.0000610351f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(neg_float)), abs_tol);
// positive subnorm float value to bf8 and back, check if holds
pos_float = 0.0000305175f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(pos_float)), abs_tol);
// negative subnorm float value to bf8 and back, check if holds
neg_float = -0.0000152587f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(neg_float)), abs_tol);
}
TEST(BF8, ConvertFP16Nearest)
TEST(BF8FNUZ, ConvertFP16Nearest)
{
// fix the tolerance value
float abs_tol = 1e-3;
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{0.0})), abs_tol);
ASSERT_NEAR(
half_t{0.0}, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Min())),
type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
abs_tol);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const auto max_bf8_t_half = type_convert<half_t>(ck::NumericLimits<bf8_fnuz_t>::Max());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR(
half_t{57344.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{57344.0})), abs_tol);
max_bf8_t_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(max_bf8_t_half)), abs_tol);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(half_t{57344.0},
type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Max())),
ASSERT_NEAR(max_bf8_t_half,
type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
abs_tol);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<bf8_t>(0x80),
f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
// convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN
ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol);
// positive norm fp16 value to bf8 and back, check if holds
half_t pos_half = half_t{0.0000762939};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol);
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(pos_half)), abs_tol);
// negative norm fp16 value to bf8 and back, check if holds
half_t neg_half = half_t{-0.0000610351};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol);
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half = half_t{0.0000305175};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol);
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half = half_t{-0.0000152587};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol);
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(neg_half)), abs_tol);
}
TEST(BF8, ConvertFP16Stochastic)
TEST(BF8FNUZ, ConvertFP16Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-3;
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{0.0})), abs_tol);
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Min())),
type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
abs_tol);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const auto max_bf8_t_half = type_convert<half_t>(ck::NumericLimits<bf8_fnuz_t>::Max());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR(
half_t{57344.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{57344.0})), abs_tol);
max_bf8_t_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(max_bf8_t_half)), abs_tol);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(half_t{57344.0},
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Max())),
ASSERT_NEAR(max_bf8_t_half,
type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
abs_tol);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<bf8_t>(0x80),
f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
// convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN
ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol);
// positive norm fp16 value to bf8 and back, check if holds
half_t pos_half = half_t{0.0000762939};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(pos_half)), abs_tol);
// negative norm fp16 value to bf8 and back, check if holds
half_t neg_half = half_t{-0.0000610351};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half = half_t{0.0000305175};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half = half_t{-0.0000152587};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(neg_half)), abs_tol);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bf8_ocp_t;
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::half_t;
using ck::type_convert;
TEST(BF8OCP, NumericLimits)
{ // constants given for OCP FP8
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Min(),
type_convert<bf8_ocp_t>(0x04)); // 0b00000100 = 2^-14
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
type_convert<bf8_ocp_t>(0x7B)); // 0b01111011 = 57344
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Lowest(),
type_convert<bf8_ocp_t>(0xFB)); // 0b11111011 = -57344
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::QuietNaN().data,
type_convert<bf8_ocp_t>(0x7D).data); // 0b01111101
EXPECT_FALSE(ck::NumericLimits<bf8_ocp_t>::QuietNaN() ==
ck::NumericLimits<bf8_ocp_t>::QuietNaN());
EXPECT_TRUE(ck::fp8_is_inf(type_convert<bf8_ocp_t>(0xFC)) &&
ck::fp8_is_inf(type_convert<bf8_ocp_t>(0x7C)));
}
TEST(BF8OCP, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_ocp_t>(0.0f)), 0.0f);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR(
max_bf8_t_float, type_convert<float>(f8_convert_rne<bf8_ocp_t>(max_bf8_t_float)), 0.0f);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_float,
type_convert<float>(f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive normal float value to bf8 and back, check if holds
float pos_float = 0.0000762939f; // 10*2^-17
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_ocp_t>(pos_float)), abs_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr auto neg_min_bf8 = -0.00006103515625f; //-2^-14
ASSERT_NEAR(neg_min_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(neg_min_bf8)), 0.0f);
// positive subnorm float value to bf8 and back, check if holds
constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15
ASSERT_NEAR(
pos_subnorm_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(pos_subnorm_bf8)), 0.0f);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr auto min_subnorm_bf8 = -0.0000152587890625f; //-2^-16
ASSERT_NEAR(
min_subnorm_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(min_subnorm_bf8)), 0.0f);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17
ASSERT_EQ(0.0f, type_convert<float>(f8_convert_rne<bf8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
TEST(BF8OCP, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_ocp_t>(0.0f)), 0.0f);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR(
max_bf8_t_float, type_convert<float>(f8_convert_sr<bf8_ocp_t>(max_bf8_t_float)), 0.0f);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_float,
type_convert<float>(f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive normal float value to bf8 and back, check if holds
float pos_float = 0.0000762939f; // 10*2^-17
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_ocp_t>(pos_float)), abs_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr auto neg_min_bf8 = -0.00006103515625f; //-2^-14
ASSERT_NEAR(neg_min_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(neg_min_bf8)), 0.0f);
// positive subnorm float value to bf8 and back, check if holds
constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15
ASSERT_NEAR(
pos_subnorm_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(pos_subnorm_bf8)), 0.0f);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr auto min_subnorm_bf8 = -0.0000152587890625f; //-2^-16
ASSERT_NEAR(
min_subnorm_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(min_subnorm_bf8)), 0.0f);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17
ASSERT_NEAR(0.0f,
type_convert<float>(f8_convert_sr<bf8_ocp_t>(less_than_min_subnorm)),
0.0000152587890625f);
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
TEST(BF8OCP, ConvertFP16Nearest)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_tol);
const auto max_bf8_t_half_t = type_convert<half_t>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(max_bf8_t_half_t)),
half_t_zero);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_rne<bf8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr half_t pos_norm_bf8{0.0000762939f}; // 10*2^-17
ASSERT_NEAR(
pos_norm_bf8, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(pos_norm_bf8)), half_t_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr half_t neg_min_bf8{-0.00006103515625f}; //-2^-14
ASSERT_NEAR(
neg_min_bf8, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(neg_min_bf8)), half_t_zero);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr half_t pos_subnorm_bf8{0.000030517578125f}; // 2^-15
ASSERT_NEAR(pos_subnorm_bf8,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(pos_subnorm_bf8)),
half_t_zero);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr half_t min_subnorm_bf8{-0.0000152587890625f}; //-2^-16
ASSERT_NEAR(min_subnorm_bf8,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(min_subnorm_bf8)),
half_t_zero);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr half_t less_than_min_subnorm{0.00000762939453125f}; // 2^-17
ASSERT_EQ(half_t_zero, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
TEST(BF8OCP, ConvertFP16Stochastic)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
constexpr auto min_subnorm_bf8 = 0.0000152587890625f; // 2^-16
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t (6.103515625e-05) to fp8 and back
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_zero);
const auto max_bf8_t_half_t = type_convert<half_t>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(max_bf8_t_half_t)),
half_t_zero);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_sr<bf8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr half_t pos_norm_bf8{0.0000762939f}; // 10*2^-17
ASSERT_NEAR(
pos_norm_bf8, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(pos_norm_bf8)), half_t_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr half_t neg_min_bf8{-0.00006103515625f}; //-2^-14
ASSERT_NEAR(
neg_min_bf8, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(neg_min_bf8)), half_t_zero);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr half_t pos_subnorm_bf8{0.000030517578125f}; // 2^-15
ASSERT_NEAR(pos_subnorm_bf8,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(pos_subnorm_bf8)),
half_t_zero);
// min subnorm bf8 value to bf8 and back, check if holds
ASSERT_NEAR(half_t{-min_subnorm_bf8},
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(half_t{-min_subnorm_bf8})),
half_t_zero);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr half_t less_than_min_subnorm{0.00000762939453125f}; // 2^-17
ASSERT_NEAR(half_t_zero,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(less_than_min_subnorm)),
half_t{min_subnorm_bf8});
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
......@@ -872,3 +872,153 @@ TEST(Complex_half, TestAsTypeReshape)
test_vec.at(num_elem * i + 1));
});
}
#if CK_USE_OCP_FP8
TEST(FP8OCP, TestSize)
{
static_assert(std::is_same_v<f8_t, ck::f8_ocp_t>, "OCP FP8 is not enabled");
ASSERT_EQ(sizeof(f8_t), sizeof(ck::fp8_storage_t));
ASSERT_EQ(sizeof(vector_type<f8_t, 2>), sizeof(vector_type<ck::fp8_storage_t, 2>));
ASSERT_EQ(sizeof(vector_type<f8_t, 4>), sizeof(vector_type<ck::fp8_storage_t, 4>));
ASSERT_EQ(sizeof(vector_type<f8_t, 8>), sizeof(vector_type<ck::fp8_storage_t, 8>));
ASSERT_EQ(sizeof(vector_type<f8_t, 16>), sizeof(vector_type<ck::fp8_storage_t, 16>));
ASSERT_EQ(sizeof(vector_type<f8_t, 32>), sizeof(vector_type<ck::fp8_storage_t, 32>));
ASSERT_EQ(sizeof(vector_type<f8_t, 64>), sizeof(vector_type<ck::fp8_storage_t, 64>));
}
TEST(FP8OCP, TestAsType)
{
static_assert(std::is_same_v<f8_t, ck::f8_ocp_t>, "OCP FP8 is not enabled");
// test size
std::array<float, 8> test_vec = {-4, -2, -0.5, -0.25, 1.0 / 8.0, 1, 1.5, 16};
constexpr int size = test_vec.size();
// reference vector
vector_type<f8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<f8_t>()(Number<i>{}), f8_t{0}); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f8_t>()(Number<i>{}) = ck::type_convert<f8_t>(test_vec.at(i));
});
// copy the vector
vector_type<f8_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f8_t>()(Number<i>{}),
ck::type_convert<f8_t>(test_vec.at(i)));
});
}
TEST(FP8OCP, TestAsTypeReshape)
{
static_assert(std::is_same_v<f8_t, ck::f8_ocp_t>, "OCP FP8 is not enabled");
// test size
std::array<float, 8> test_vec = {-8, -0.5, -0.25, 1.0 / 8.0, 1 / 256, 1, 1.5, 16};
constexpr int size = test_vec.size();
// reference vector
vector_type<f8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<f8_t>()(Number<i>{}), f8_t{0}); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f8_t>()(Number<i>{}) = ck::type_convert<f8_t>(test_vec.at(i));
});
// copy the first half of a vector
vector_type<f8_t, size / 2> left_vec{
right_vec.template AsType<vector_type<f8_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f8_t>()(Number<i>{}),
ck::type_convert<f8_t>(test_vec.at(i)));
});
}
TEST(BF8OCP, TestSize)
{
static_assert(std::is_same_v<bf8_t, ck::bf8_ocp_t>, "OCP BF8 is not enabled");
ASSERT_EQ(sizeof(bf8_t), sizeof(ck::fp8_storage_t));
ASSERT_EQ(sizeof(vector_type<bf8_t, 2>), sizeof(vector_type<ck::fp8_storage_t, 2>));
ASSERT_EQ(sizeof(vector_type<bf8_t, 4>), sizeof(vector_type<ck::fp8_storage_t, 4>));
ASSERT_EQ(sizeof(vector_type<bf8_t, 8>), sizeof(vector_type<ck::fp8_storage_t, 8>));
ASSERT_EQ(sizeof(vector_type<bf8_t, 16>), sizeof(vector_type<ck::fp8_storage_t, 16>));
ASSERT_EQ(sizeof(vector_type<bf8_t, 32>), sizeof(vector_type<ck::fp8_storage_t, 32>));
ASSERT_EQ(sizeof(vector_type<bf8_t, 64>), sizeof(vector_type<ck::fp8_storage_t, 64>));
}
TEST(BF8OCP, TestAsType)
{
static_assert(std::is_same_v<bf8_t, ck::bf8_ocp_t>, "OCP BF8 is not enabled");
// test size
std::array<float, 8> test_vec = {-4, -2, -0.5, -0.25, 1.0 / 8.0, 1, 1.5, 16};
constexpr int size = test_vec.size();
// reference vector
vector_type<bf8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<bf8_t>()(Number<i>{}), bf8_t{0}); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<bf8_t>()(Number<i>{}) = ck::type_convert<bf8_t>(test_vec.at(i));
});
// copy the vector
vector_type<bf8_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<bf8_t>()(Number<i>{}),
ck::type_convert<bf8_t>(test_vec.at(i)));
});
}
TEST(BF8OCP, TestAsTypeReshape)
{
static_assert(std::is_same_v<bf8_t, ck::bf8_ocp_t>, "OCP BF8 is not enabled");
// test size
std::array<float, 8> test_vec = {-8, -0.5, -0.25, 1.0 / 8.0, 1 / 256, 1, 1.5, 16};
constexpr int size = test_vec.size();
// reference vector
vector_type<bf8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<bf8_t>()(Number<i>{}), bf8_t{0}); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<bf8_t>()(Number<i>{}) = ck::type_convert<bf8_t>(test_vec.at(i));
});
// copy the first half of a vector
vector_type<bf8_t, size / 2> left_vec{
right_vec.template AsType<vector_type<bf8_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<bf8_t>()(Number<i>{}),
ck::type_convert<bf8_t>(test_vec.at(i)));
});
}
#endif
......@@ -7,154 +7,171 @@
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::f8_t;
using ck::f8_fnuz_t;
using ck::half_t;
using ck::type_convert;
TEST(FP8, NumericLimits)
TEST(FP8FNUZ, NumericLimits)
{
// constants given for negative zero nan mode
EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), type_convert<f8_t>(0x08));
EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), type_convert<f8_t>(0x7F));
EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), type_convert<f8_t>(0xFF));
EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), type_convert<f8_t>(0x80));
EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Min(), type_convert<f8_fnuz_t>(0x08));
EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Max(), type_convert<f8_fnuz_t>(0x7F));
EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Lowest(), type_convert<f8_fnuz_t>(0xFF));
EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::QuietNaN(), type_convert<f8_fnuz_t>(0x80));
}
TEST(FP8, ConvertFP32Nearest)
TEST(FP8FNUZ, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_t>(0.0f)), abs_tol);
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_fnuz_t>(0.0f)), abs_tol);
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::min())),
type_convert<float>(f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::min())),
abs_tol);
#endif
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_rne<f8_t>(240.0f)), abs_tol);
// convert maximal float to fp8 and back, check if clipped to 240.0
ASSERT_NEAR(240.0f,
type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::max())),
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR(
max_f8_t_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(max_f8_t_float)), abs_tol);
// XXX: FNUZ f8_convert_rne behavior is inconsistent.
// Clipping large values to fp8 max (saturation to finite) contradicts converting inf float to
// fp8 qNAN (no saturation).
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::max())),
abs_tol);
// convert inf float to f8_t and check if it is qNan
ASSERT_NEAR(type_convert<f8_t>(0x80),
f8_convert_rne<f8_t>(std::numeric_limits<float>::infinity()),
// convert inf float to f8_fnuz_t and check if it is qNan
ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::infinity()),
abs_tol);
// positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol);
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(pos_float)), abs_tol);
// negative norm float value to fp8 and back, check if holds
float neg_float = -0.015625f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol);
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(neg_float)), abs_tol);
// positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol);
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(pos_float)), abs_tol);
// negative subnorm float value to fp8 and back, check if holds
neg_float = -0.001953125f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol);
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(neg_float)), abs_tol);
}
TEST(FP8, ConvertFP32Stochastic)
TEST(FP8FNUZ, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_t>(0.0f)), abs_tol);
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_fnuz_t>(0.0f)), abs_tol);
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::min())),
type_convert<float>(f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::min())),
abs_tol);
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_sr<f8_t>(240.0f)), abs_tol);
// convert maximal float to fp8 and back, check if clipped to 240.0
ASSERT_NEAR(240.0f,
type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())),
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR(
max_f8_t_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(max_f8_t_float)), abs_tol);
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::max())),
abs_tol);
// convert inf float to f8_t and check if it is qNan
ASSERT_NEAR(type_convert<f8_t>(0x80),
f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()),
// convert inf float to f8_fnuz_t and check if it is qNan
ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::infinity()),
abs_tol);
// positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(pos_float)), abs_tol);
// negative norm float value to fp8 and back, check if holds
float neg_float = -0.015625f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(neg_float)), abs_tol);
// positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(pos_float)), abs_tol);
// negative subnorm float value to fp8 and back, check if holds
neg_float = -0.001953125f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(neg_float)), abs_tol);
}
TEST(FP8, ConvertFP16Nearest)
TEST(FP8FNUZ, ConvertFP16Nearest)
{
// fix the tolerance value
float abs_tol = 1e-3;
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{0.0})), abs_tol);
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Min())),
type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
abs_tol);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{240.0})), abs_tol);
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
ASSERT_NEAR(half_t{240.0},
type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Max())),
const auto max_f8_t_half = type_convert<half_t>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR(
max_f8_t_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(max_f8_t_half)), abs_tol);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_half,
type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
abs_tol);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<f8_t>(0x80),
f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
// convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN
ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol);
// positive norm fp16 value to fp8 and back, check if holds
half_t pos_half = half_t{0.017578125};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol);
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(pos_half)), abs_tol);
// negative norm fp16 value to fp8 and back, check if holds
half_t neg_half = half_t{-0.015625};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol);
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half = half_t{0.00390625};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol);
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half = half_t{-0.001953125};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol);
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(neg_half)), abs_tol);
}
TEST(FP8, ConvertFP16Stochastic)
TEST(FP8FNUZ, ConvertFP16Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-3;
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{0.0})), abs_tol);
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Min())),
type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
abs_tol);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{240.0})), abs_tol);
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
ASSERT_NEAR(half_t{240.0},
type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())),
const auto max_f8_t_half = type_convert<half_t>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR(
max_f8_t_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(max_f8_t_half)), abs_tol);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_half,
type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
abs_tol);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<f8_t>(0x80),
f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
// convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN
ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol);
// positive norm fp16 value to fp8 and back, check if holds
half_t pos_half = half_t{0.017578125};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(pos_half)), abs_tol);
// negative norm fp16 value to fp8 and back, check if holds
half_t neg_half = half_t{-0.015625};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half = half_t{0.00390625};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half = half_t{-0.001953125};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(neg_half)), abs_tol);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment