Commit e15e7ee3 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

tidying up tests

parent c6a7d7cb
...@@ -29,8 +29,8 @@ ...@@ -29,8 +29,8 @@
* *
* In the large number limit, the fractional counts approach the multinomial distribution. * In the large number limit, the fractional counts approach the multinomial distribution.
* *
* Inputs: args[0] - a vector of probabilities for each category. Values are running totals * Inputs: args[0] - a vector of probabilities for each category. Values are running
as provided by op prefix_scan_sum. * totals as provided by op prefix_scan_sum.
* Values are log normalized (i.e. start with any set of numbers > 0, then * Values are log normalized (i.e. start with any set of numbers > 0, then
* val[i] = log(val[i]) / sum (log(val[0]) + log(val[1])+ ...) ) * val[i] = log(val[i]) / sum (log(val[0]) + log(val[1])+ ...) )
* This input has Rank 2. Dimension 0 is batch #. The size of dimension * This input has Rank 2. Dimension 0 is batch #. The size of dimension
......
...@@ -80,7 +80,7 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -80,7 +80,7 @@ struct parse_multinomial : op_parser<parse_multinomial>
} }
instruction_ref randoms; instruction_ref randoms;
if(args.size() > 0) if(not args.empty())
{ {
shape s0 = args[0]->get_shape(); shape s0 = args[0]->get_shape();
......
...@@ -4191,6 +4191,7 @@ TEST_CASE(multinomial_test) ...@@ -4191,6 +4191,7 @@ TEST_CASE(multinomial_test)
TEST_CASE(multinomial_dyn_test) TEST_CASE(multinomial_dyn_test)
{ {
// compile-time random seed
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
size_t sample_size = 13; size_t sample_size = 13;
...@@ -4206,7 +4207,6 @@ TEST_CASE(multinomial_dyn_test) ...@@ -4206,7 +4207,6 @@ TEST_CASE(multinomial_dyn_test)
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}}; migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}};
std::vector<float> data(rs.elements(), 0.3f);
migraphx::shape s{migraphx::shape::uint32_type, {1}}; migraphx::shape s{migraphx::shape::uint32_type, {1}};
std::vector<float> seed_data = {seed}; std::vector<float> seed_data = {seed};
...@@ -4216,7 +4216,6 @@ TEST_CASE(multinomial_dyn_test) ...@@ -4216,7 +4216,6 @@ TEST_CASE(multinomial_dyn_test)
auto ret = mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms); auto ret = mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);
mm->add_return({ret}); mm->add_return({ret});
// auto prog = optimize_onnx("multinomial_dyn_test.onnx");
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10}; options.default_dyn_dim_value = {1, 10};
options.print_program_on_error = true; options.print_program_on_error = true;
...@@ -4226,6 +4225,7 @@ TEST_CASE(multinomial_dyn_test) ...@@ -4226,6 +4225,7 @@ TEST_CASE(multinomial_dyn_test)
TEST_CASE(multinomial_autoseed_dyn_test) TEST_CASE(multinomial_autoseed_dyn_test)
{ {
// runtime random seed
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
size_t sample_size = 13; size_t sample_size = 13;
...@@ -4247,7 +4247,6 @@ TEST_CASE(multinomial_autoseed_dyn_test) ...@@ -4247,7 +4247,6 @@ TEST_CASE(multinomial_autoseed_dyn_test)
auto ret = mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms); auto ret = mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);
mm->add_return({ret}); mm->add_return({ret});
// auto prog = optimize_onnx("multinomial_dyn_test.onnx");
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10}; options.default_dyn_dim_value = {1, 10};
options.print_program_on_error = true; options.print_program_on_error = true;
......
...@@ -5377,6 +5377,7 @@ TEST_CASE(multinomial_dyn_test) ...@@ -5377,6 +5377,7 @@ TEST_CASE(multinomial_dyn_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) { std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum; return static_cast<double>(n) / res_dist_sum;
}); });
// The given test tolerance is about 10x the typical error
EXPECT(migraphx::verify::verify_range(norm, res_norm, 100000)); EXPECT(migraphx::verify::verify_range(norm, res_norm, 100000));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment