Commit 2ef739bd authored by Brian Pickrell's avatar Brian Pickrell
Browse files

breaks ref_ops_test

parent 1f9e7402
...@@ -29,11 +29,11 @@ ...@@ -29,11 +29,11 @@
* be given as a runtime argument containing a single value, or a compile-time * be given as a runtime argument containing a single value, or a compile-time
* attribute. * attribute.
* *
* Inputs: (1) the shape of the set to be populated. * Inputs: (1) randomization seed (uint32)
* (2) randomization seed (uint32). If not given at inference time, the attribute * (2) the shape of the set to be populated.
* value, or auto seeding, will be used. *
* *
* Attributes: seed uint32 Randomization seed * Attributes: none
* *
* Output: Same shape. * Output: Same shape.
* *
...@@ -53,7 +53,8 @@ namespace op { ...@@ -53,7 +53,8 @@ namespace op {
struct rand_uniform struct rand_uniform
{ {
uint32_t seed = {0}; // The rand_uniform operation does not contain a random number generator seed
// as a member, and expects it to be passed as a runtime input.
// todo: not currently settable // todo: not currently settable
float range_min = 0.0f; float range_min = 0.0f;
...@@ -65,17 +66,21 @@ struct rand_uniform ...@@ -65,17 +66,21 @@ struct rand_uniform
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.dtype, "dtype"), f(self.seed, "seed")); return pack(f(self.dtype, "dtype"));
} }
/**
* Input 1: seed
* Input 2: output shape
*/
std::string name() const { return "rand_uniform"; } std::string name() const { return "rand_uniform"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1, 2); check_shapes{inputs, *this, true}.has(2);
if(inputs.size() > 1 and inputs.at(1).type() != shape::type_t::uint32_type) if(inputs.front().type() != shape::type_t::uint32_type)
MIGRAPHX_THROW("RAND_UNIFORM: Input 2 (seed) must have type unsigned int"); MIGRAPHX_THROW("RAND_UNIFORM: Input 2 (seed) must have type unsigned int");
auto s = inputs.front(); auto s = inputs.at(1);
if(s.dynamic()) if(s.dynamic())
{ {
return s.with_type(dtype); return s.with_type(dtype);
...@@ -86,25 +91,22 @@ struct rand_uniform ...@@ -86,25 +91,22 @@ struct rand_uniform
} }
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const shape& output, std::vector<argument>& args) const
{ {
argument result{dyn_out.computed_shape}; (void)output;
argument& result{args[1]};
auto local_seed(seed); uint32_t local_seed = args[0].at<uint32_t>(0);
if(args.size() > 1)
{
local_seed = args[1].at<uint32_t>(0);
}
// If a seed argument was not defined, use the value from the seed attribute,
// or the default.
std::mt19937 gen(local_seed); std::mt19937 gen(local_seed);
std::uniform_real_distribution<> dis(range_min, range_max); std::uniform_real_distribution<> dis(range_min, range_max);
result.visit([&](auto output) { result.visit([&](auto output_shape) {
std::generate(output.begin(), output.end(), [&]() { return dis(gen); }); std::generate(output_shape.begin(), output_shape.end(), [&]() { return dis(gen); });
}); });
return result; return result;
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 1; }
}; };
} // namespace op } // namespace op
......
...@@ -2219,17 +2219,9 @@ TEST_CASE(prefix_scan_sum_dyn_2d) ...@@ -2219,17 +2219,9 @@ TEST_CASE(prefix_scan_sum_dyn_2d)
TEST_CASE(rand_uniform) TEST_CASE(rand_uniform)
{ {
std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}, {3, 7}}; std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}, {3, 7}};
migraphx::shape s0{migraphx::shape::float_type, {1}};
migraphx::shape s1{migraphx::shape::float_type, dd}; migraphx::shape s1{migraphx::shape::float_type, dd};
expect_shape(s1, migraphx::make_op("rand_uniform", {{"seed", 1}}), s1); expect_shape(s1, migraphx::make_op("rand_uniform"), s0, s1);
}
TEST_CASE(rand_uniform_2args)
{
std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}, {3, 7}};
migraphx::shape s1{migraphx::shape::float_type, dd};
migraphx::shape s2{migraphx::shape::uint32_type, dd};
expect_shape(s1, migraphx::make_op("rand_uniform", {{"seed", 1}}), s1, s2);
} }
TEST_CASE(random_seed) TEST_CASE(random_seed)
......
...@@ -6477,12 +6477,7 @@ TEST_CASE(rand_uniform_test) ...@@ -6477,12 +6477,7 @@ TEST_CASE(rand_uniform_test)
std::vector<uint32_t> seed_data{seed}; std::vector<uint32_t> seed_data{seed};
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data)); auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
mm->add_instruction(migraphx::make_op("rand_uniform", mm->add_instruction(migraphx::make_op("rand_uniform"), seed_input, input);
{
{"seed", seed},
}),
input,
seed_input);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params0; migraphx::parameter_map params0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment