dot_apply_alpha_beta_test.cpp 4.61 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include "migraphx/instruction.hpp"
#include <migraphx/common.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>

// Compares the module produced by migraphx::insert_dot_apply_alpha_beta for
// half_type operands (alpha = 3, beta = 2) against a reference module assembled
// by hand. In the reference, the alpha and beta multiplications are carried out
// on float_type values and the products are converted back to half_type before
// the dot and the add.
TEST_CASE(dot_apply_alpha_beta_half)
{
    // Module built through the helper under test.
    migraphx::module m1;
    {
        migraphx::shape half_2x2{migraphx::shape::half_type, {2, 2}};
        auto a      = m1.add_parameter("x", half_2x2);
        auto b      = m1.add_parameter("y", half_2x2);
        auto c      = m1.add_parameter("z", half_2x2);
        auto result = migraphx::insert_dot_apply_alpha_beta(m1, m1.end(), {a, b, c}, 3, 2);
        m1.add_instruction(migraphx::make_op("identity"), result);
    }

    // Reference module; instructions are appended in the order the test expects.
    migraphx::module m2;
    {
        auto half_t  = migraphx::shape::half_type;
        auto float_t = migraphx::shape::float_type;
        migraphx::shape half_2x2{half_t, {2, 2}};

        // Appends a "convert" of `ins` to the element type `target`.
        auto cast = [&](auto ins, auto target) {
            return m2.add_instruction(migraphx::make_op("convert", {{"target_type", target}}),
                                      ins);
        };
        // Appends a "multibroadcast" of `scalar` to the lens of `like`.
        auto splat = [&](auto scalar, auto like) {
            return m2.add_instruction(
                migraphx::make_op("multibroadcast", {{"out_lens", like->get_shape().lens()}}),
                scalar);
        };

        auto a = m2.add_parameter("x", half_2x2);
        auto b = m2.add_parameter("y", half_2x2);
        auto c = m2.add_parameter("z", half_2x2);

        // dot(half(alpha * float(x)), y)
        auto alpha      = m2.add_literal(3.0f);
        auto alpha_full = splat(alpha, a);
        auto a_wide     = cast(a, float_t);
        auto a_scaled   = m2.add_instruction(migraphx::make_op("mul"), alpha_full, a_wide);
        auto a_narrow   = cast(a_scaled, half_t);
        auto product    = m2.add_instruction(migraphx::make_op("dot"), a_narrow, b);

        // half(float(z) * beta)
        auto beta      = m2.add_literal(2.0f);
        auto c_wide    = cast(c, float_t);
        auto beta_full = splat(beta, c);
        auto c_scaled  = m2.add_instruction(migraphx::make_op("mul"), c_wide, beta_full);
        auto c_narrow  = cast(c_scaled, half_t);

        auto sum = m2.add_instruction(migraphx::make_op("add"), product, c_narrow);
        m2.add_instruction(migraphx::make_op("identity"), sum);
    }
    EXPECT(m1 == m2);
}

// Compares the module produced by migraphx::add_dot_apply_alpha_beta for
// double_type operands (alpha = 3, beta = 2, z of lens {2, 1}) against a
// reference module assembled by hand. In the reference, the broadcast alpha and
// beta literals are converted to double_type before multiplying, and z is
// multibroadcast to {2, 2} before it is scaled and added.
TEST_CASE(dot_apply_alpha_beta_double)
{
    // Module built through the helper under test.
    migraphx::module m1;
    {
        migraphx::shape square{migraphx::shape::double_type, {2, 2}};
        migraphx::shape column{migraphx::shape::double_type, {2, 1}};
        auto a      = m1.add_parameter("x", square);
        auto b      = m1.add_parameter("y", square);
        auto c      = m1.add_parameter("z", column);
        auto result = migraphx::add_dot_apply_alpha_beta(m1, {a, b, c}, 3, 2);
        m1.add_instruction(migraphx::make_op("identity"), result);
    }

    // Reference module; instructions are appended in the order the test expects.
    migraphx::module m2;
    {
        auto double_t = migraphx::shape::double_type;
        migraphx::shape square{double_t, {2, 2}};
        migraphx::shape column{double_t, {2, 1}};

        // Appends a "convert" of `ins` to double_type.
        auto to_double = [&](auto ins) {
            return m2.add_instruction(migraphx::make_op("convert", {{"target_type", double_t}}),
                                      ins);
        };
        // Appends a "multibroadcast" of `scalar` to the lens of `like`.
        auto splat = [&](auto scalar, auto like) {
            return m2.add_instruction(
                migraphx::make_op("multibroadcast", {{"out_lens", like->get_shape().lens()}}),
                scalar);
        };

        auto a = m2.add_parameter("x", square);
        auto b = m2.add_parameter("y", square);
        auto c = m2.add_parameter("z", column);

        // dot(double(alpha) * x, y)
        auto alpha        = m2.add_literal(3.0f);
        auto alpha_full   = splat(alpha, a);
        auto alpha_as_dbl = to_double(alpha_full);
        auto a_scaled     = m2.add_instruction(migraphx::make_op("mul"), alpha_as_dbl, a);
        auto product      = m2.add_instruction(migraphx::make_op("dot"), a_scaled, b);

        // broadcast(z) * double(beta)
        auto c_full =
            m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), c);
        auto beta        = m2.add_literal(2.0f);
        auto beta_full   = splat(beta, c_full);
        auto beta_as_dbl = to_double(beta_full);
        auto c_scaled    = m2.add_instruction(migraphx::make_op("mul"), c_full, beta_as_dbl);

        auto sum = m2.add_instruction(migraphx::make_op("add"), product, c_scaled);
        m2.add_instruction(migraphx::make_op("identity"), sum);
    }
    EXPECT(m1 == m2);
}

// Test binary entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}