Commit 562ec121 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Rearrange f8_utils' namespaces

parent c208a8ac
...@@ -17,6 +17,12 @@ enum class f8_rounding_mode ...@@ -17,6 +17,12 @@ enum class f8_rounding_mode
stochastic stochastic
}; };
} // namespace ck
namespace ck::utils {
namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{ {
...@@ -127,17 +133,6 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -127,17 +133,6 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
} }
// Entry point for converting a half_t or float value to f8_t.
// Any other source type T is rejected at compile time; otherwise x and rng are
// forwarded unchanged to run_cast_to_f8 with the same template arguments.
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
    // only half_t and float sources are accepted
    static_assert(std::is_same<T, half_t>::value || std::is_same<T, float>::value,
                  "Only half and float can be casted to f8.");

    return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
}
template <typename T, bool negative_zero_nan> template <typename T, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x) __host__ __device__ T run_cast_from_f8(f8_t x)
{ {
...@@ -225,6 +220,19 @@ __host__ __device__ T run_cast_from_f8(f8_t x) ...@@ -225,6 +220,19 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
return *(reinterpret_cast<const T*>(&retval)); return *(reinterpret_cast<const T*>(&retval));
} }
} // namespace
// Converts x (half_t or float) to f8_t.
// The source type is validated at compile time, after which the work is
// delegated to run_cast_to_f8 using the identical template arguments; rng is
// passed through untouched.
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
    // compile-time guard on the source datatype
    constexpr bool supported_source =
        std::is_same<T, half_t>::value || std::is_same<T, float>::value;
    static_assert(supported_source, "Only half and float can be casted to f8.");

    return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
}
template <typename T, bool negative_zero_nan> template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x) __host__ __device__ T cast_from_f8(f8_t x)
{ {
...@@ -240,4 +248,4 @@ __host__ __device__ T cast_from_f8(f8_t x) ...@@ -240,4 +248,4 @@ __host__ __device__ T cast_from_f8(f8_t x)
return run_cast_from_f8<T, negative_zero_nan>(x); return run_cast_from_f8<T, negative_zero_nan>(x);
} }
} // namespace ck } // namespace ck::utils
...@@ -111,7 +111,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) ...@@ -111,7 +111,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng); return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
} }
// convert fp8 to fp32 // convert fp8 to fp32
...@@ -119,7 +120,7 @@ template <> ...@@ -119,7 +120,7 @@ template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{ {
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return cast_from_f8<float, negative_zero_nan>(x); return utils::cast_from_f8<float, negative_zero_nan>(x);
} }
// convert fp16 to fp8 // convert fp16 to fp8
...@@ -130,8 +131,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) ...@@ -130,8 +131,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
rng); x, rng);
} }
// convert fp8 to fp16 // convert fp8 to fp16
...@@ -139,7 +140,7 @@ template <> ...@@ -139,7 +140,7 @@ template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{ {
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return cast_from_f8<half_t, negative_zero_nan>(x); return utils::cast_from_f8<half_t, negative_zero_nan>(x);
} }
// Declare a template function for bf16 conversion using RTN // Declare a template function for bf16 conversion using RTN
...@@ -213,7 +214,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) ...@@ -213,7 +214,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr int seed = 42; constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation // as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng); return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
} }
// convert fp16 to fp8 with stochastic rounding // convert fp16 to fp8 with stochastic rounding
...@@ -226,8 +228,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -226,8 +228,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr int seed = 42; constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation // as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
rng); x, rng);
} }
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment