Commit c597a533 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

misc cleanup in response to PR feedback

parent 50da6a8c
...@@ -27,8 +27,6 @@ ...@@ -27,8 +27,6 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/reflect.hpp>
#include <random> #include <random>
namespace migraphx { namespace migraphx {
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
* be given as a runtime argument containing a single value, or a compile-time * be given as a runtime argument containing a single value, or a compile-time
* attribute. * attribute.
* *
* Inputs: (1) randomization seed (uint64) * Inputs: (1) randomization seed (any type is allowed)
* (2) the shape of the set to be populated. * (2) the shape of the set to be populated.
* *
* *
...@@ -43,7 +43,6 @@ ...@@ -43,7 +43,6 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <random> #include <random>
namespace migraphx { namespace migraphx {
...@@ -68,15 +67,7 @@ struct random_uniform ...@@ -68,15 +67,7 @@ struct random_uniform
{ {
check_shapes{inputs, *this, true}.has(2); check_shapes{inputs, *this, true}.has(2);
auto s = inputs.at(1); return inputs.at(1);
if(s.dynamic())
{
return s;
}
else
{
return s.with_lens(s.lens());
}
} }
argument compute(const shape&, std::vector<argument> args) const argument compute(const shape&, std::vector<argument> args) const
......
...@@ -6480,8 +6480,8 @@ TEST_CASE(random_uniform_test) ...@@ -6480,8 +6480,8 @@ TEST_CASE(random_uniform_test)
mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input); mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params0; // no params_map needed
auto result = p.eval(params0).back(); auto result = p.eval({}).back();
std::vector<float> result_vec(sample_size); std::vector<float> result_vec(sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
...@@ -6490,7 +6490,7 @@ TEST_CASE(random_uniform_test) ...@@ -6490,7 +6490,7 @@ TEST_CASE(random_uniform_test)
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100000)); EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100));
} }
TEST_CASE(random_uniform_int_test) TEST_CASE(random_uniform_int_test)
...@@ -6611,7 +6611,7 @@ TEST_CASE(random_seed_test) ...@@ -6611,7 +6611,7 @@ TEST_CASE(random_seed_test)
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<uint32_t> result_vec1(1); std::vector<uint64_t> result_vec1(1);
result.visit([&](auto output) { result_vec1.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec1.assign(output.begin(), output.end()); });
std::vector<uint64_t> result_vec2(1); std::vector<uint64_t> result_vec2(1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment