gemm_half_test.cpp 1.64 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

#include <onnx_test.hpp>
#include <migraphx/apply_alpha_beta.hpp>

// Builds, by hand, the program expected from "gemm_half_test.onnx" and checks
// that optimize_onnx() returns an equal program.
//
// The hand-built graph uses half-precision parameters A {8, 6}, B {8, 7} and
// C {6, 1}: A is scaled by 0.5, converted to half and transposed, then fed to
// "dot" together with B; C is broadcast to {6, 7}, scaled by 0.8 in float,
// converted back to half and added to the dot result.
//
// NOTE(review): the add_* calls are kept in a fixed sequence because the final
// equality check presumably depends on instruction order - confirm before
// reordering anything here.
TEST_CASE(gemm_half_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    // Graph inputs, all half precision.
    const migraphx::shape a_shape{migraphx::shape::half_type, {8, 6}};
    const migraphx::shape b_shape{migraphx::shape::half_type, {8, 7}};
    const migraphx::shape c_shape{migraphx::shape::half_type, {6, 1}};
    auto param_a = main_mod->add_parameter("A", a_shape);
    auto param_b = main_mod->add_parameter("B", b_shape);
    auto param_c = main_mod->add_parameter("C", c_shape);

    // Operators that are used more than once below.
    const auto mul_op = migraphx::make_op("mul");
    const auto to_half_op =
        migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}});

    // alpha * A, converted to half, then transposed with permutation {1, 0}.
    const float alpha_scale = 0.5f;
    auto alpha_lit          = main_mod->add_literal(alpha_scale);
    auto scaled_a           = add_common_op(*main_mod, mul_op, {alpha_lit, param_a});
    auto scaled_a_half      = main_mod->add_instruction(to_half_op, scaled_a);
    auto scaled_a_t         = main_mod->add_instruction(
        migraphx::make_op("transpose", {{"permutation", {1, 0}}}), scaled_a_half);

    // Matrix product; alpha/beta of 1 and 0 are passed to the helper because the
    // scaling is expressed explicitly in this graph.
    std::vector<std::size_t> out_dims = {6, 7};
    const auto broadcast_op = migraphx::make_op("multibroadcast", {{"out_lens", out_dims}});
    auto product            = migraphx::add_apply_alpha_beta(
        *main_mod, {scaled_a_t, param_b}, migraphx::make_op("dot"), 1.0f, 0.0f);

    // beta * C: broadcast C to {6, 7}, multiply in float, convert back to half.
    auto c_broadcast = main_mod->add_instruction(broadcast_op, param_c);
    auto c_float     = main_mod->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}),
        c_broadcast);
    const float beta_scale = 0.8f;
    auto beta_lit          = main_mod->add_literal(beta_scale);
    auto beta_broadcast    = main_mod->add_instruction(broadcast_op, beta_lit);
    auto scaled_c          = main_mod->add_instruction(mul_op, c_float, beta_broadcast);
    auto scaled_c_half     = main_mod->add_instruction(to_half_op, scaled_c);

    // Final sum of the product and the scaled bias term.
    main_mod->add_instruction(migraphx::make_op("add"), product, scaled_c_half);

    auto parsed = optimize_onnx("gemm_half_test.onnx");
    EXPECT(expected == parsed);
}