"vscode:/vscode.git/clone" did not exist on "2b52fbd24a63e6d43081b4db0913b6e9cca8e400"
Commit 5c043ac8 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

refactored rand_uniform op. to accept a mandatory shape argument and an optional seed argument.

parent ec9e35c1
...@@ -25,7 +25,9 @@ ...@@ -25,7 +25,9 @@
/** /**
* Random Uniform distribution operator. Given a shape, populate it with random values. * Random Uniform distribution operator. Given a shape, populate it with random values.
* *
* Inputs: any tensor shape. * Inputs: (1) the shape of the set to be populated.
* (2) randomization seed. Optional--if not given, a seed will be generated
* automatically, for nonrepeatable random results.
* Attributes: TBD * Attributes: TBD
* *
Output: Same shape. Output: Same shape.
...@@ -47,8 +49,23 @@ namespace op { ...@@ -47,8 +49,23 @@ namespace op {
struct rand_uniform struct rand_uniform
{ {
uint32_t sample_size = {20}; uint32_t sample_size = {20};
uint32_t seed = {0}; uint32_t seed = {3};
shape::type_t dtype = shape::type_t::float_type; float range_min = 0.0f;
float range_max = 1.0f;
// From Onnx RandomUniform:
// dtype : int (default is 1)
// The data type for the elements of the output tensor. If not specified, default is
// TensorProto::FLOAT. high : float (default is 1.0) Upper boundary of the output values. low :
// float (default is 0.0) Lower boundary of the output values. seed : float (Optional) Seed to
// the random generator, if not specified we will auto generate one. shape : list of ints
// (required) The shape of the output tensor.
// TODO: consider removing this and simply using the type of the passed argument.
// The only bar to doing this currently is that we can't create random integers within the
// current bounds of (0, 1).
shape::type_t dtype = shape::type_t::float_type;
// std::vector<size_t> output_lens = {1};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -62,19 +79,18 @@ struct rand_uniform ...@@ -62,19 +79,18 @@ struct rand_uniform
std::string name() const { return "rand_uniform"; } std::string name() const { return "rand_uniform"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1, 2);
if(inputs.size() > 1 and inputs.at(1).element_space() > 0 and
inputs.at(1).type() != shape::type_t::uint32_type)
MIGRAPHX_THROW("RAND_UNIFORM: Input 2 (seed) must have type unsigned int");
auto s = inputs.front(); auto s = inputs.front();
if(s.dynamic()) if(s.dynamic())
{ {
return s; return s.with_type(dtype);
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
} }
else else
{ {
return s.with_lens(s.lens()); return s.with_lens(s.lens()).with_type(dtype);
} }
} }
...@@ -83,18 +99,24 @@ struct rand_uniform ...@@ -83,18 +99,24 @@ struct rand_uniform
(void)args; // suppress compiler warning (void)args; // suppress compiler warning
argument result{dyn_out.computed_shape}; argument result{dyn_out.computed_shape};
std::mt19937 gen(seed); auto local_seed(seed);
std::uniform_real_distribution<> dis(0.0, 1.0); if(args.size() > 1)
size_t elts(dyn_out.computed_shape.elements()); {
// Use of our visitor and par_for replaces a call like if(args.at(1).get_shape().element_space() > 0)
// std::vector<float> rand_samples(sample_size); {
// std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); visit_all(args[1])([&](auto data) { local_seed = data[0]; });
local_seed++;
}
else // obtain a seed from the system clock:
local_seed = std::chrono::system_clock::now().time_since_epoch().count();
}
// If a seed argument was not defined, use the value from the seed attribute,
// or the default.
std::mt19937 gen(local_seed);
std::uniform_real_distribution<> dis(range_min, range_max);
result.visit([&](auto output) { result.visit([&](auto output) {
par_for(elts, [&](auto i) { std::generate(output.begin(), output.end(), [&]() { return dis(gen); });
output[i] = dis(gen);
// output[i] = rand_samples[i];
});
}); });
return result; return result;
} }
......
...@@ -5305,7 +5305,7 @@ TEST_CASE(multinomial_dyn_test) ...@@ -5305,7 +5305,7 @@ TEST_CASE(multinomial_dyn_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
size_t sample_size = 1000000; size_t sample_size = 1000000;
float seed = 0.0f; uint32_t seed = 4;
// Shape of the random data // Shape of the random data
migraphx::shape rs{migraphx::shape::float_type, {{1, 2}, {2, sample_size + 1}}}; migraphx::shape rs{migraphx::shape::float_type, {{1, 2}, {2, sample_size + 1}}};
...@@ -5332,7 +5332,18 @@ TEST_CASE(multinomial_dyn_test) ...@@ -5332,7 +5332,18 @@ TEST_CASE(multinomial_dyn_test)
cdf = mm->add_instruction( cdf = mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
auto randoms = mm->add_instruction(migraphx::make_op("rand_uniform", {{"seed", seed}}), input); // To seed the rand_uniform, we can provide a value by literal or input,
// or we can pass it a 0-size shape in which case it will auto-seed.
migraphx::shape seed_shape{migraphx::shape::uint32_type, {0}};
std::vector<int32_t> seed_data = {};
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
auto randoms = mm->add_instruction(migraphx::make_op("rand_uniform",
{
{"seed", seed},
}),
input,
seed_input); // <==some_seed is something user-supplied
mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms); mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -5366,7 +5377,7 @@ TEST_CASE(multinomial_dyn_test) ...@@ -5366,7 +5377,7 @@ TEST_CASE(multinomial_dyn_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) { std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum; return static_cast<double>(n) / res_dist_sum;
}); });
EXPECT(migraphx::verify::verify_range(norm, res_norm, 1000000)); EXPECT(migraphx::verify::verify_range(norm, res_norm, 100000));
} }
TEST_CASE(neg_test) TEST_CASE(neg_test)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment