Commit 50da6a8c authored by Brian Pickrell's avatar Brian Pickrell
Browse files

misc code cleanup. Seed can be any type.

parent aa517bd9
......@@ -44,35 +44,30 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/reflect.hpp>
#include <random>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* random_uniform populates the passed shape with random numbers, in a uniform
* distribution. Range for floating-point data types is (0, 1);
* for integer types it is [0, <max value for the type>]
*
* Input 1: seed
* Input 2: output shape
*/
struct random_uniform
{
// The random_uniform operation does not contain a random number generator seed
// as a member, and expects it to be passed as a runtime input.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack();
}
// The random_uniform operation needs the random number generator seed
// to be passed as a runtime input.
/**
* Input 1: seed
* Input 2: output shape
*/
std::string name() const { return "random_uniform"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(2);
if(inputs.front().type() != shape::type_t::uint64_type)
MIGRAPHX_THROW("RANDOM_UNIFORM: Input 1 (seed) must have type long unsigned int");
auto s = inputs.at(1);
if(s.dynamic())
{
......@@ -98,7 +93,7 @@ struct random_uniform
{
// default range for all integer types is (0,
// std::uniform_int_distribution<type>::max()).
// To clamp to a different range, apply min or max ops. to the output of this.
// Todo: enable different ranges
std::uniform_int_distribution<type> dis;
std::generate(output.begin(), output.end(), [&] { return dis(gen); });
}
......
......@@ -6498,7 +6498,7 @@ TEST_CASE(random_uniform_int_test)
// random uniform distribution with an integer type input shape
migraphx::program p;
auto* mm = p.get_main_module();
uint64_t seed(0);
float seed(0.1);
size_t sample_size(200);
// Shape of the random data
......@@ -6509,8 +6509,8 @@ TEST_CASE(random_uniform_int_test)
auto input = mm->add_literal(migraphx::literal(rs, data));
// Runtime randomization seed
migraphx::shape seed_shape{migraphx::shape::uint64_type, {1}};
std::vector<uint64_t> seed_data{seed};
migraphx::shape seed_shape{migraphx::shape::float_type, {1}};
std::vector<float> seed_data{seed};
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment