"test/vscode:/vscode.git/clone" did not exist on "e5f05e71ab57ece752f209ffc219d4105dd74e3c"
Commit c5e22952 authored by Rostyslav Geyyer
Browse files

Use element location for PRNG

parent 789862ca
......@@ -1151,7 +1151,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
// uint32_t x = reinterpret_cast<uint32_t&>(val);
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits ^= x >> 16;
......@@ -1168,7 +1167,6 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
// uint16_t x = reinterpret_cast<uint16_t&>(val);
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
......@@ -1210,7 +1208,7 @@ inline __host__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(0, x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
}
......@@ -1222,7 +1220,7 @@ inline __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(get_thread_global_1d_id(), x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
}
......@@ -1235,7 +1233,7 @@ inline __host__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(0, x);
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
}
......@@ -1248,7 +1246,7 @@ inline __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(get_thread_global_1d_id(), x);
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment