Commit f1c2ec74 authored by Rostyslav Geyyer
Browse files

Merge host and device implementations

parent ee568bc2
...@@ -1190,17 +1190,13 @@ __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed ...@@ -1190,17 +1190,13 @@ __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed
return 0; return 0;
} }
// Declare a template function for fp8 conversion using SR on host // Declare a template function for fp8 conversion using SR
template <typename Y, typename X> template <typename Y, typename X>
__host__ constexpr Y f8_convert_sr(X x); __host__ __device__ constexpr Y f8_convert_sr(X x);
// Declare a template function for fp8 conversion using SR on device // convert fp32 to fp8 with stochastic rounding
// Device-only declaration of the stochastic-rounding (SR) conversion template.
// Y is the destination type and X the source type; no generic body is given
// here, so behavior comes from the explicit specializations of f8_convert_sr.
template <typename Y, typename X>
__device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding on host
template <> template <>
inline __host__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{ {
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
...@@ -1211,21 +1207,9 @@ inline __host__ f8_t f8_convert_sr<f8_t, float>(float x) ...@@ -1211,21 +1207,9 @@ inline __host__ f8_t f8_convert_sr<f8_t, float>(float x)
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng); return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
} }
// convert fp32 to fp8 with stochastic rounding on device // convert fp16 to fp8 with stochastic rounding
// Device-only specialization: fp32 -> fp8 conversion using stochastic rounding.
template <>
inline __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
    // Compile-time policy flags forwarded to cast_to_f8 as template arguments.
    constexpr bool kNegativeZeroNan = true;
    constexpr bool kClip            = true;
    constexpr f8_rounding_mode kMode = f8_rounding_mode::stochastic;
    constexpr bool kUseStochastic    = (kMode == f8_rounding_mode::stochastic);

    // Fixed seed for the pseudo-random generator.
    constexpr int kSeed = 42;

    // The generator is keyed on the address of the by-value argument plus its value.
    const auto id      = reinterpret_cast<uintptr_t>(&x);
    const uint32_t rng = prand_generator<float, kSeed>(id, x);

    return cast_to_f8<float, kNegativeZeroNan, kClip, kUseStochastic>(x, rng);
}
// convert fp16 to fp8 with stochastic rounding on host
template <> template <>
inline __host__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{ {
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
...@@ -1237,19 +1221,6 @@ inline __host__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -1237,19 +1221,6 @@ inline __host__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
rng); rng);
} }
// Device-only specialization: fp16 -> fp8 conversion using stochastic rounding.
template <>
inline __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
    // Compile-time policy flags forwarded to cast_to_f8 as template arguments.
    constexpr bool kNegativeZeroNan = true;
    constexpr bool kClip            = true;
    constexpr f8_rounding_mode kMode = f8_rounding_mode::stochastic;
    constexpr bool kUseStochastic    = (kMode == f8_rounding_mode::stochastic);

    // Fixed seed for the pseudo-random generator.
    constexpr int kSeed = 42;

    // The generator is keyed on the address of the by-value argument plus its value.
    const auto id      = reinterpret_cast<uintptr_t>(&x);
    const uint32_t rng = prand_generator<half_t, kSeed>(id, x);

    return cast_to_f8<half_t, kNegativeZeroNan, kClip, kUseStochastic>(x, rng);
}
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment