Commit a15551e5 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

change multinomial_dyn_test to send 0-size seed input. This triggers auto seed

parent b4b864a7
......@@ -75,6 +75,8 @@ struct parse_multinomial : op_parser<parse_multinomial>
{
shape s0 = args[0]->get_shape();
// TODO: Use literal if batch size is fixed
// TODO: Add second argument for seed (an Migraphx rule, not Onnx) if Onnx seed not given
// It will be a literal with a shape of 0 size
if(s0.dynamic())
{
// Dynamic batch_size will be taken from args[0]. Other contents of input are
......
......@@ -5311,6 +5311,12 @@ TEST_CASE(multinomial_dyn_test)
migraphx::shape rs{migraphx::shape::float_type, {{1, 2}, {2, sample_size + 1}}};
auto input = mm->add_parameter("Input_1", rs);
// Runtime randomization seed
// To seed the rand_uniform, we can provide a value by literal or input,
// or we can pass it a 0-size shape in which case it will auto-seed.
migraphx::shape seed_shape{migraphx::shape::uint32_type, {migraphx::shape::dynamic_dimension{0, 1}}};
auto seed_input = mm->add_parameter("Seed", seed_shape);
// Shape of the probability distribution, which also defines the number of categories
migraphx::shape s{migraphx::shape::float_type, {{1, 1}, {5, 6}}};
std::vector<int> dist{15, 25, 15, 25, 20};
......@@ -5332,18 +5338,12 @@ TEST_CASE(multinomial_dyn_test)
cdf = mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
// To seed the rand_uniform, we can provide a value by literal or input,
// or we can pass it a 0-size shape in which case it will auto-seed.
migraphx::shape seed_shape{migraphx::shape::uint32_type, {0}};
std::vector<int32_t> seed_data = {};
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
auto randoms = mm->add_instruction(migraphx::make_op("rand_uniform",
{
{"seed", seed},
}),
input,
seed_input); // <==some_seed is something user-supplied
seed_input);
mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);
p.compile(migraphx::make_target("ref"));
......@@ -5354,6 +5354,11 @@ TEST_CASE(multinomial_dyn_test)
migraphx::shape input_fixed_shape2{migraphx::shape::float_type, {1, 5}};
migraphx::parameter_map params0;
params0["Input_1"] = migraphx::argument(input_fixed_shape1, dummy.data());
migraphx::shape seed_fixed_shape{migraphx::shape::uint32_type, {0}};
std::vector<uint32_t> seed_data = {};
params0["Seed"] = migraphx::argument(seed_fixed_shape, seed_data.data());
params0["Input_2"] = migraphx::argument(input_fixed_shape2, data.data());
auto result = p.eval(params0).back();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment