Commit c597a533 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

misc cleanup in response to PR feedback

parent 50da6a8c
......@@ -27,8 +27,6 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/reflect.hpp>
#include <random>
namespace migraphx {
......
......@@ -29,7 +29,7 @@
* be given as a runtime argument containing a single value, or a compile-time
* attribute.
*
* Inputs: (1) randomization seed (uint64)
* Inputs: (1) randomization seed (any type is allowed)
* (2) the shape of the set to be populated.
*
*
......@@ -43,7 +43,6 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <random>
namespace migraphx {
......@@ -68,15 +67,7 @@ struct random_uniform
{
check_shapes{inputs, *this, true}.has(2);
auto s = inputs.at(1);
if(s.dynamic())
{
return s;
}
else
{
return s.with_lens(s.lens());
}
return inputs.at(1);
}
argument compute(const shape&, std::vector<argument> args) const
......
......@@ -6480,8 +6480,8 @@ TEST_CASE(random_uniform_test)
mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params0;
auto result = p.eval(params0).back();
// no params_map needed
auto result = p.eval({}).back();
std::vector<float> result_vec(sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
......@@ -6490,7 +6490,7 @@ TEST_CASE(random_uniform_test)
std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100000));
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100));
}
TEST_CASE(random_uniform_int_test)
......@@ -6611,7 +6611,7 @@ TEST_CASE(random_seed_test)
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint32_t> result_vec1(1);
std::vector<uint64_t> result_vec1(1);
result.visit([&](auto output) { result_vec1.assign(output.begin(), output.end()); });
std::vector<uint64_t> result_vec2(1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment