Commit 789862ca authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Use seed as a runtime arg

parent 502942fe
......@@ -1148,8 +1148,8 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
// Pseudo random number generator
// version for fp32
template <typename T, uint32_t seed, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val)
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
// uint32_t x = reinterpret_cast<uint32_t&>(val);
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
......@@ -1165,8 +1165,8 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val)
}
// version for fp16
template <typename T, uint32_t seed, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val)
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
// uint16_t x = reinterpret_cast<uint16_t&>(val);
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
......@@ -1182,12 +1182,13 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val)
// return 0 if data is not fp16 or fp32
template <typename T,
uint32_t seed,
uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val)
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
std::ignore = id;
std::ignore = val;
std::ignore = seed;
return 0;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment