Commit 562ec121 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Rearrange f8_utils' namespaces

parent c208a8ac
......@@ -17,6 +17,12 @@ enum class f8_rounding_mode
stochastic
};
} // namespace ck
namespace ck::utils {
namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{
......@@ -127,17 +133,6 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
}
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
// check datatype
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted to f8.");
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
}
template <typename T, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x)
{
......@@ -225,6 +220,19 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
return *(reinterpret_cast<const T*>(&retval));
}
} // namespace
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
// check datatype
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted to f8.");
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
}
template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x)
{
......@@ -240,4 +248,4 @@ __host__ __device__ T cast_from_f8(f8_t x)
return run_cast_from_f8<T, negative_zero_nan>(x);
}
} // namespace ck
} // namespace ck::utils
......@@ -111,7 +111,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
// convert fp8 to fp32
......@@ -119,7 +120,7 @@ template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
constexpr bool negative_zero_nan = true;
return cast_from_f8<float, negative_zero_nan>(x);
return utils::cast_from_f8<float, negative_zero_nan>(x);
}
// convert fp16 to fp8
......@@ -130,8 +131,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
// convert fp8 to fp16
......@@ -139,7 +140,7 @@ template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
constexpr bool negative_zero_nan = true;
return cast_from_f8<half_t, negative_zero_nan>(x);
return utils::cast_from_f8<half_t, negative_zero_nan>(x);
}
// Declare a template function for bf16 conversion using RTN
......@@ -213,7 +214,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
// convert fp16 to fp8 with stochastic rounding
......@@ -226,8 +228,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment