Commit 333ce7d0 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

add random_uniform_int_test

parent 99a9d56f
......@@ -6493,6 +6493,42 @@ TEST_CASE(random_uniform_test)
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100000));
}
TEST_CASE(random_uniform_int_test)
{
// random uniform distribution with an integer type input shape
migraphx::program p;
auto* mm = p.get_main_module();
uint64_t seed(0);
size_t sample_size(200);
// Shape of the random data
migraphx::shape rs{migraphx::shape::uint16_type, {1, sample_size}};
// data tensor must be allocated at this point but does not need to be initialized.
std::vector<uint16_t> data(sample_size);
auto input = mm->add_literal(migraphx::literal(rs, data));
// Runtime randomization seed
migraphx::shape seed_shape{migraphx::shape::uint64_type, {1}};
std::vector<uint64_t> seed_data{seed};
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params0;
auto result = p.eval(params0).back();
std::vector<uint16_t> result_vec(sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
// Compare result with the STL's mt19937 generator
std::mt19937 gen(seed);
std::uniform_int_distribution<uint16_t> dis;
std::vector<uint16_t> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100000));
}
TEST_CASE(random_uniform_dyn_test)
{
migraphx::program p;
......@@ -6516,7 +6552,7 @@ TEST_CASE(random_uniform_dyn_test)
migraphx::parameter_map params0;
params0["Input_1"] = migraphx::argument(input_fixed_shape1);
// migraphx::shape seed_fixed_shape{migraphx::shape::uint64_type, {1}};
std::vector<uint64_t> seed_data = {seed};
params0["Seed"] = migraphx::argument(seed_shape, seed_data.data());
auto result = p.eval(params0).back();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment