Commit 50da6a8c authored by Brian Pickrell's avatar Brian Pickrell
Browse files

misc code cleanup. Seed can be any type.

parent aa517bd9
...@@ -44,35 +44,30 @@ ...@@ -44,35 +44,30 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/reflect.hpp>
#include <random> #include <random>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct random_uniform /**
{ * random_uniform populates the passed shape with random numbers, in a uniform
// The random_uniform operation does not contain a random number generator seed * distribution. Range for floating-point data types is (0, 1);
// as a member, and expects it to be passed as a runtime input. * for integer types it is [0, <max value for the type>]
*
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack();
}
/**
* Input 1: seed * Input 1: seed
* Input 2: output shape * Input 2: output shape
*/ */
struct random_uniform
{
// The random_uniform operation needs the random number generator seed
// to be passed as a runtime input.
std::string name() const { return "random_uniform"; } std::string name() const { return "random_uniform"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(2); check_shapes{inputs, *this, true}.has(2);
if(inputs.front().type() != shape::type_t::uint64_type)
MIGRAPHX_THROW("RANDOM_UNIFORM: Input 1 (seed) must have type long unsigned int");
auto s = inputs.at(1); auto s = inputs.at(1);
if(s.dynamic()) if(s.dynamic())
{ {
...@@ -98,7 +93,7 @@ struct random_uniform ...@@ -98,7 +93,7 @@ struct random_uniform
{ {
// default range for all integer types is (0, // default range for all integer types is (0,
// std::uniform_int_distribution<type>::max()). // std::uniform_int_distribution<type>::max()).
// To clamp to a different range, apply min or max ops. to the output of this. // Todo: enable different ranges
std::uniform_int_distribution<type> dis; std::uniform_int_distribution<type> dis;
std::generate(output.begin(), output.end(), [&] { return dis(gen); }); std::generate(output.begin(), output.end(), [&] { return dis(gen); });
} }
......
...@@ -6498,7 +6498,7 @@ TEST_CASE(random_uniform_int_test) ...@@ -6498,7 +6498,7 @@ TEST_CASE(random_uniform_int_test)
// random uniform distribution with an integer type input shape // random uniform distribution with an integer type input shape
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
uint64_t seed(0); float seed(0.1);
size_t sample_size(200); size_t sample_size(200);
// Shape of the random data // Shape of the random data
...@@ -6509,8 +6509,8 @@ TEST_CASE(random_uniform_int_test) ...@@ -6509,8 +6509,8 @@ TEST_CASE(random_uniform_int_test)
auto input = mm->add_literal(migraphx::literal(rs, data)); auto input = mm->add_literal(migraphx::literal(rs, data));
// Runtime randomization seed // Runtime randomization seed
migraphx::shape seed_shape{migraphx::shape::uint64_type, {1}}; migraphx::shape seed_shape{migraphx::shape::float_type, {1}};
std::vector<uint64_t> seed_data{seed}; std::vector<float> seed_data{seed};
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data)); auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input); mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment