multinomial_dyn_test.cpp 2.5 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

#include <onnx_test.hpp>

TEST_CASE(multinomial_dyn_test)
{
    // Hand-builds the program expected for multinomial_dyn_test.onnx and checks that the
    // ONNX parser produces an equal one. The random seed is fixed at compile time (it is
    // embedded as a literal), while the "input" parameter has a dynamic first dimension.
    const size_t sample_size = 100000;
    const size_t categories  = 5;
    const float seed         = 1.3f;

    migraphx::program p;
    auto* mm = p.get_main_module();

    // Dynamic parameter: first dimension ranges over [1, categories], second is fixed.
    const migraphx::shape input_shape{migraphx::shape::float_type,
                                      {{1, categories}, {categories, categories}}};
    auto logits = mm->add_parameter("input", input_shape);

    // Subtract the maximum along axis 1, exponentiate, then take an inclusive running sum
    // along axis 1.
    auto row_max = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), logits);
    auto shifted = add_common_op(*mm, migraphx::make_op("sub"), {logits, row_max});
    auto exped   = mm->add_instruction(migraphx::make_op("exp"), shifted);
    auto cdf     = mm->add_instruction(
        migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), exped);

    // Single-element float literal carrying the seed.
    migraphx::shape seed_shape{migraphx::shape::float_type, {1}};
    std::vector<float> seed_values = {seed};
    auto seed_lit                  = mm->add_literal(migraphx::literal(seed_shape, seed_values));

    // Dynamic input only: the allocation shape (batch_size, sample_size) has to be computed
    // from the runtime input dimensions.
    auto runtime_dims =
        mm->add_instruction(migraphx::make_op("dimensions_of", {{"end", 2}}), logits);

    migraphx::shape pair_shape(migraphx::shape::int64_type, {2});

    // Multiplying by (1, 0) keeps the first dimension and zeroes the second one ...
    std::vector<int64_t> keep_batch{1, 0};
    auto keep_batch_lit = mm->add_literal(pair_shape, keep_batch);
    auto batch_only = mm->add_instruction(migraphx::make_op("mul"), runtime_dims, keep_batch_lit);

    // ... and adding (0, sample_size) fills the second slot with the sample count.
    std::vector<int64_t> sample_offset{0, static_cast<int64_t>(sample_size)};
    auto sample_offset_lit = mm->add_literal(pair_shape, sample_offset);
    auto alloc_dims =
        mm->add_instruction(migraphx::make_op("add"), batch_only, sample_offset_lit);

    // Shape recorded on the allocate op: the input's first dynamic dimension paired with a
    // fixed sample_size dimension.
    const auto batch_dim = logits->get_shape().dyn_dims().front();
    migraphx::shape compile_shape =
        migraphx::shape(migraphx::shape::float_type, {batch_dim, {sample_size, sample_size}});

    auto buffer = mm->add_instruction(
        migraphx::make_op("allocate", {{"shape", to_value(compile_shape)}}), alloc_dims);

    auto uniform = mm->add_instruction(migraphx::make_op("random_uniform"), seed_lit, buffer);
    auto result  = mm->add_instruction(
        migraphx::make_op("multinomial", {{"dtype", migraphx::shape::float_type}}), cdf, uniform);
    mm->add_return({result});

    // Parse the ONNX model with a matching default dynamic dimension and compare.
    migraphx::onnx_options options;
    options.default_dyn_dim_value  = {1, categories};
    options.print_program_on_error = true;

    auto prog = migraphx::parse_onnx("multinomial_dyn_test.onnx", options);
    EXPECT(p == prog);
}