Commit c5e22952 authored by Rostyslav Geyyer
Browse files

Use element location for PRNG

parent 789862ca
...@@ -1151,7 +1151,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h ...@@ -1151,7 +1151,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false> template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{ {
// uint32_t x = reinterpret_cast<uint32_t&>(val);
uint32_t x = *(reinterpret_cast<uint32_t*>(&val)); uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu; uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits ^= x >> 16; drop_bits ^= x >> 16;
...@@ -1168,7 +1167,6 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = ...@@ -1168,7 +1167,6 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false> template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{ {
// uint16_t x = reinterpret_cast<uint16_t&>(val);
uint16_t x = *(reinterpret_cast<uint16_t*>(&val)); uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu; uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
...@@ -1210,7 +1208,7 @@ inline __host__ f8_t f8_convert_sr<f8_t, float>(float x) ...@@ -1210,7 +1208,7 @@ inline __host__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42; constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation // as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(0, x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng); return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
} }
...@@ -1222,7 +1220,7 @@ inline __device__ f8_t f8_convert_sr<f8_t, float>(float x) ...@@ -1222,7 +1220,7 @@ inline __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42; constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(get_thread_global_1d_id(), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng); return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
} }
...@@ -1235,7 +1233,7 @@ inline __host__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -1235,7 +1233,7 @@ inline __host__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42; constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation // as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(0, x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng); rng);
} }
...@@ -1248,7 +1246,7 @@ inline __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -1248,7 +1246,7 @@ inline __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42; constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(get_thread_global_1d_id(), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng); rng);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment