Commit 052ab48a authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Update type_converts

parent 0c460962
......@@ -1058,7 +1058,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return cast_to_f8<negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
}
// convert fp8 to fp32
......@@ -1066,7 +1066,26 @@ template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
constexpr bool negative_zero_nan = true;
return cast_from_f8<negative_zero_nan>(x);
return cast_from_f8<float, negative_zero_nan>(x);
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
}
// convert fp8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
constexpr bool negative_zero_nan = true;
return cast_from_f8<half_t, negative_zero_nan>(x);
}
// Declare a template function for bf16 conversion using RTN
......@@ -1190,7 +1209,7 @@ inline __host__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(0, x);
return cast_to_f8<negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
}
// convert fp32 to fp8 with stochastic rounding on device
......@@ -1202,7 +1221,32 @@ inline __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(get_thread_global_1d_id(), x);
return cast_to_f8<negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
}
// convert fp16 to fp8 with stochastic rounding on host
template <>
inline __host__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(0, x);
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
}
// convert fp16 to fp8 with stochastic rounding on device
template <>
inline __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(get_thread_global_1d_id(), x);
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
}
template <typename T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment