Commit 402192b8 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Re-try add of op. random_seed and remove use_auto_seed.

parent 2e26eb5d
...@@ -31,8 +31,9 @@ ...@@ -31,8 +31,9 @@
* *
* Inputs: (1) the shape of the set to be populated. * Inputs: (1) the shape of the set to be populated.
* (2) randomization seed (uint32). If not given at inference time, the attribute * (2) randomization seed (uint32). If not given at inference time, the attribute
* value, or auto seeding, will be used. Attributes: use_auto_seed bool Have hardware generate * value, or auto seeding, will be used.
* random seed at runtime, overriding the attribute seed seed uint32 Randomization seed *
* Attributes: seed uint32 Randomization seed
* *
* Output: Same shape. * Output: Same shape.
* *
...@@ -53,7 +54,6 @@ namespace op { ...@@ -53,7 +54,6 @@ namespace op {
struct rand_uniform struct rand_uniform
{ {
uint32_t seed = {0}; uint32_t seed = {0};
bool use_auto_seed = false;
// todo: not currently settable // todo: not currently settable
float range_min = 0.0f; float range_min = 0.0f;
...@@ -65,8 +65,7 @@ struct rand_uniform ...@@ -65,8 +65,7 @@ struct rand_uniform
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack( return pack(f(self.dtype, "dtype"), f(self.seed, "seed"));
f(self.dtype, "dtype"), f(self.seed, "seed"), f(self.use_auto_seed, "use_auto_seed"));
} }
std::string name() const { return "rand_uniform"; } std::string name() const { return "rand_uniform"; }
......
...@@ -40,7 +40,7 @@ namespace op { ...@@ -40,7 +40,7 @@ namespace op {
* at runtime guarantees there will be a different random sequence on every execution. * at runtime guarantees there will be a different random sequence on every execution.
* This operation has no inputs or attributes, and outputs an unsigned integer tensor with * This operation has no inputs or attributes, and outputs an unsigned integer tensor with
* a single value. * a single value.
*/ */
struct random_seed struct random_seed
{ {
shape::type_t dtype = shape::type_t::uint32_type; shape::type_t dtype = shape::type_t::uint32_type;
...@@ -48,27 +48,27 @@ struct random_seed ...@@ -48,27 +48,27 @@ struct random_seed
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack( return pack(f(self.dtype, "dtype"));
f(self.dtype, "dtype"));
} }
std::string name() const { return "random_seed"; } std::string name() const { return "random_seed"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
(void) inputs; (void)inputs;
return migraphx::shape(dtype, {1}); return migraphx::shape(dtype, {1});
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
(void) args; (void)args;
argument result(output_shape); argument result(output_shape);
result.visit([&](auto output) { result.visit([&](auto output) {
std::generate(output.begin(), output.end(), [&]() { return uint32_t(std::chrono::system_clock::now().time_since_epoch().count()); }); std::generate(output.begin(), output.end(), [&]() {
return uint32_t(std::chrono::system_clock::now().time_since_epoch().count());
});
}); });
return result; return result;
} }
}; };
......
...@@ -6542,7 +6542,6 @@ TEST_CASE(rand_uniform_dyn_test) ...@@ -6542,7 +6542,6 @@ TEST_CASE(rand_uniform_dyn_test)
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100000)); EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100000));
} }
TEST_CASE(rand_uniform_and_seed_test) TEST_CASE(rand_uniform_and_seed_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -6556,9 +6555,7 @@ TEST_CASE(rand_uniform_and_seed_test) ...@@ -6556,9 +6555,7 @@ TEST_CASE(rand_uniform_and_seed_test)
// Runtime randomization seed // Runtime randomization seed
auto seed_input = mm->add_instruction(migraphx::make_op("random_seed")); auto seed_input = mm->add_instruction(migraphx::make_op("random_seed"));
mm->add_instruction(migraphx::make_op("rand_uniform"), mm->add_instruction(migraphx::make_op("rand_uniform"), input, seed_input);
input,
seed_input);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
// Create a dummy input to hold the random data // Create a dummy input to hold the random data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment