multinomial_int64_test.cpp 1.45 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

#include <onnx_test.hpp>

// Builds, by hand, the program expected from parsing
// "multinomial_int64_test.onnx" and checks that the parsed result is equal to it.
//
// The reference program takes a {1, 10} float parameter, subtracts its axis-1
// maximum, exponentiates, and takes an inclusive prefix sum along axis 1. That
// result and the output of random_uniform are then passed to a multinomial
// instruction whose dtype attribute is int64.
TEST_CASE(multinomial_int64_test)
{
    const size_t num_samples               = 10;
    const uint32_t num_batches             = 1;
    const float seed_value                 = 1.0;
    const migraphx::shape::type_t out_type = migraphx::shape::type_t::int64_type;

    migraphx::program p;
    auto* mm = p.get_main_module();

    // Parameter and its maximum reduced over axis 1.
    auto logits =
        mm->add_parameter("input", migraphx::shape{migraphx::shape::float_type, {1, 10}});
    auto row_max = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), logits);

    // exp(input - max), followed by an inclusive running sum along axis 1.
    auto shifted = add_common_op(*mm, migraphx::make_op("sub"), {logits, row_max});
    auto exps    = mm->add_instruction(migraphx::make_op("exp"), shifted);
    auto cdf     = mm->add_instruction(
        migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), exps);

    // Single-element float literal carrying the seed.
    const migraphx::shape seed_shape{migraphx::shape::float_type, {1}};
    const std::vector<float> seed_data{seed_value};
    auto seed_lit = mm->add_literal(migraphx::literal(seed_shape, seed_data));

    // static size
    // Value-initialized {num_batches, num_samples} literal given to random_uniform as
    // its second argument (presumably only its shape matters -- TODO confirm).
    const migraphx::shape rand_shape{migraphx::shape::float_type, {num_batches, num_samples}};
    const std::vector<float> rand_data(num_batches * num_samples);
    auto rand_lit = mm->add_literal(migraphx::literal{rand_shape, rand_data});
    auto uniform  = mm->add_instruction(migraphx::make_op("random_uniform"), seed_lit, rand_lit);

    mm->add_instruction(migraphx::make_op("multinomial", {{"dtype", out_type}}), cdf, uniform);

    // Program produced from the ONNX file; it must equal the reference above.
    auto prog = optimize_onnx("multinomial_int64_test.onnx");

    EXPECT(p == prog);
}