Commit 0d8e489b authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Format

parent 74d97e51
......@@ -13,12 +13,12 @@ using half_t = _Float16;
using int4_t = _BitInt(4);
#endif
struct f8_t
struct f8_t
{
uint8_t data;
using type = f8_t;
using type = f8_t;
using data_type = uint8_t;
__host__ __device__ f8_t() = default;
__host__ __device__ f8_t(uint8_t init);
};
......@@ -26,7 +26,7 @@ struct f8_t
struct bf8_t
{
uint8_t data;
using type = bf8_t;
using type = bf8_t;
using data_type = uint8_t;
__host__ __device__ bf8_t() = default;
......@@ -205,7 +205,8 @@ struct vector_type<T, 1>
template <typename T>
struct vector_type<T, 2>
{
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using T_adjusted =
typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
......@@ -257,7 +258,8 @@ struct vector_type<T, 2>
template <typename T>
struct vector_type<T, 4>
{
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using T_adjusted =
typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
......@@ -321,7 +323,8 @@ struct vector_type<T, 4>
template <typename T>
struct vector_type<T, 8>
{
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using T_adjusted =
typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
......@@ -397,7 +400,8 @@ struct vector_type<T, 8>
template <typename T>
struct vector_type<T, 16>
{
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using T_adjusted =
typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
......@@ -485,7 +489,8 @@ struct vector_type<T, 16>
template <typename T>
struct vector_type<T, 32>
{
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using T_adjusted =
typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
......@@ -583,7 +588,8 @@ struct vector_type<T, 32>
template <typename T>
struct vector_type<T, 64>
{
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using T_adjusted =
typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
......@@ -693,7 +699,8 @@ struct vector_type<T, 64>
template <typename T>
struct vector_type<T, 128>
{
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using T_adjusted =
typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
......@@ -813,7 +820,8 @@ struct vector_type<T, 128>
template <typename T>
struct vector_type<T, 256>
{
using T_adjusted = typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using T_adjusted =
typename std::conditional<std::is_same<T, f8_t>::value, f8_t::data_type, T>::type;
using d1_t = T_adjusted;
typedef T_adjusted d2_t __attribute__((ext_vector_type(2)));
......@@ -1047,33 +1055,33 @@ struct NumericLimits<f8_t>
static constexpr uint8_t binary_max = 0x77; // 0b01110111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
__host__ __device__ static f8_t Min()
__host__ __device__ static f8_t Min()
{
f8_t x;
x.data=binary_min;
return x;
x.data = binary_min;
return x;
}
__host__ __device__ static f8_t Max()
__host__ __device__ static f8_t Max()
{
f8_t x;
x.data=binary_max;
return x;
x.data = binary_max;
return x;
}
__host__ __device__ static f8_t Lowest()
__host__ __device__ static f8_t Lowest()
{
f8_t x;
x.data=binary_lowest;
return x;
x.data = binary_lowest;
return x;
}
__host__ __device__ static f8_t QuietNaN()
__host__ __device__ static f8_t QuietNaN()
{
f8_t x;
x.data=binary_qnan;
return x;
x.data = binary_qnan;
return x;
}
};
......
......@@ -250,7 +250,4 @@ __host__ __device__ T cast_from_f8(uint8_t x)
} // namespace ck::utils
// f8_t constuctor impl
inline __host__ __device__ ck::f8_t::f8_t(uint8_t init)
{
data = init;
}
inline __host__ __device__ ck::f8_t::f8_t(uint8_t init) { data = init; }
......@@ -106,8 +106,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return f8_t(utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
return f8_t(
utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
}
// convert fp8 to fp32
......@@ -126,8 +127,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return f8_t(utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
return f8_t(
utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
}
// convert fp8 to fp16
......@@ -209,8 +211,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return f8_t(utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
return f8_t(
utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
}
// convert fp16 to fp8 with stochastic rounding
......@@ -223,8 +226,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return f8_t(utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
return f8_t(
utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng));
}
} // namespace ck
......@@ -216,7 +216,8 @@ check_err(const Range& out,
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, f8_t>), bool>
std::is_same_v<ranges::range_value_t<Range>, f8_t>),
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment