Commit 99a9d56f authored by Brian Pickrell's avatar Brian Pickrell
Browse files

added integer types to random_uniform and removed range_max and range_min attributes

parent 10432892
...@@ -62,9 +62,7 @@ struct random_seed ...@@ -62,9 +62,7 @@ struct random_seed
{ {
argument result(output_shape); argument result(output_shape);
result.visit([&](auto output) { result.visit([&](auto output) { output.front() = std::random_device{}(); });
output.front() = std::random_device{}();
});
return result; return result;
} }
}; };
......
...@@ -56,13 +56,10 @@ struct random_uniform ...@@ -56,13 +56,10 @@ struct random_uniform
// The random_uniform operation does not contain a random number generator seed // The random_uniform operation does not contain a random number generator seed
// as a member, and expects it to be passed as a runtime input. // as a member, and expects it to be passed as a runtime input.
float range_min = 0.0f;
float range_max = 1.0f;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.range_min, "range_min"), f(self.range_max, "range_max")); return pack();
} }
/** /**
...@@ -90,14 +87,27 @@ struct random_uniform ...@@ -90,14 +87,27 @@ struct random_uniform
argument compute(const shape&, std::vector<argument> args) const argument compute(const shape&, std::vector<argument> args) const
{ {
// Output goes into the passed buffer, not the shape output. // Output goes into the passed buffer, not the shape output.
argument result = args[1]; auto result = args[1];
uint64_t local_seed = args[0].at<uint64_t>(0); uint64_t local_seed = args[0].at<uint64_t>(0);
std::mt19937 gen(local_seed); std::mt19937 gen(local_seed);
std::uniform_real_distribution<> dis(range_min, range_max);
result.visit([&](auto output_shape) { result.visit([&](auto output) {
std::generate(output_shape.begin(), output_shape.end(), [&]() { return dis(gen); }); using type = typename decltype(output)::value_type;
if constexpr(std::is_integral<type>{})
{
// default range for all integer types is (0, INT_MAX) which depends
// on the integral type. To clamp
// to a different range, apply min or max ops. to the output of this.
std::uniform_int_distribution<type> dis;
std::generate(output.begin(), output.end(), [&] { return dis(gen); });
}
else
{
// default real distribution type is double with range (0, 1);
std::uniform_real_distribution<> dis;
std::generate(output.begin(), output.end(), [&] { return dis(gen); });
}
}); });
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment