"tests/vscode:/vscode.git/clone" did not exist on "df7955de37d1505aee7145fa471dc94458d98666"
dot_apply_alpha_beta_test.cpp 7.08 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <cstdint>
#include <migraphx/instruction.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>

// Builds a module with insert_apply_alpha_beta on half-precision {2, 2} inputs to "dot"
// (alpha = 3, beta = 2) and expects it to compare equal to a hand-written reference module.
TEST_CASE(dot_apply_alpha_beta_half)
{
    const auto half_t  = migraphx::shape::half_type;
    const auto float_t = migraphx::shape::float_type;
    const migraphx::shape mat_shape{half_t, {2, 2}};

    // Module generated by the helper under test.
    migraphx::module m1;
    {
        auto a      = m1.add_parameter("x", mat_shape);
        auto b      = m1.add_parameter("y", mat_shape);
        auto c      = m1.add_parameter("z", mat_shape);
        auto result = migraphx::insert_apply_alpha_beta(
            m1, m1.end(), {a, b, c}, migraphx::make_op("dot"), 3.0f, 2.0f);
        m1.add_instruction(migraphx::make_op("identity"), result);
    }

    // Reference module, written out instruction by instruction. The order in which
    // instructions are added is kept deliberately, since the modules are compared as a whole.
    migraphx::module m2;
    {
        // Appends a "convert" of `ins` to element type `type`.
        auto convert_to = [&](auto type, auto ins) {
            return m2.add_instruction(migraphx::make_op("convert", {{"target_type", type}}), ins);
        };
        // Appends a "multibroadcast" of `scalar` to the lens of `like`.
        auto broadcast_like = [&](auto scalar, auto like) {
            return m2.add_instruction(
                migraphx::make_op("multibroadcast", {{"out_lens", like->get_shape().lens()}}),
                scalar);
        };

        auto a = m2.add_parameter("x", mat_shape);
        auto b = m2.add_parameter("y", mat_shape);
        auto c = m2.add_parameter("z", mat_shape);

        // alpha * x is computed in float and converted back to half before the dot.
        auto alpha        = m2.add_literal(3.0f);
        auto alpha_bcast  = broadcast_like(alpha, a);
        auto a_f32        = convert_to(float_t, a);
        auto scaled_a_f32 = m2.add_instruction(migraphx::make_op("mul"), alpha_bcast, a_f32);
        auto scaled_a     = convert_to(half_t, scaled_a_f32);
        auto product      = m2.add_instruction(migraphx::make_op("dot"), scaled_a, b);

        // beta * z is likewise computed in float and converted back to half before the add.
        auto beta         = m2.add_literal(2.0f);
        auto c_f32        = convert_to(float_t, c);
        auto beta_bcast   = broadcast_like(beta, c);
        auto scaled_c_f32 = m2.add_instruction(migraphx::make_op("mul"), c_f32, beta_bcast);
        auto scaled_c     = convert_to(half_t, scaled_c_f32);

        auto sum = m2.add_instruction(migraphx::make_op("add"), product, scaled_c);
        m2.add_instruction(migraphx::make_op("identity"), sum);
    }

    EXPECT(m1 == m2);
}

// Builds a module with add_apply_alpha_beta on double-precision inputs to "dot" (alpha = 3,
// beta = 2), where z has lens {2, 1}, and expects it to compare equal to a hand-written
// reference module.
TEST_CASE(dot_apply_alpha_beta_double)
{
    const auto double_t = migraphx::shape::double_type;
    const migraphx::shape square_shape{double_t, {2, 2}};
    const migraphx::shape column_shape{double_t, {2, 1}};

    // Module generated by the helper under test.
    migraphx::module m1;
    {
        auto a = m1.add_parameter("x", square_shape);
        auto b = m1.add_parameter("y", square_shape);
        auto c = m1.add_parameter("z", column_shape);
        auto result =
            migraphx::add_apply_alpha_beta(m1, {a, b, c}, migraphx::make_op("dot"), 3.0f, 2.0f);
        m1.add_instruction(migraphx::make_op("identity"), result);
    }

    // Reference module, written out instruction by instruction. The order in which
    // instructions are added is kept deliberately, since the modules are compared as a whole.
    migraphx::module m2;
    {
        // Appends a "convert" of `ins` to double.
        auto to_double = [&](auto ins) {
            return m2.add_instruction(migraphx::make_op("convert", {{"target_type", double_t}}),
                                      ins);
        };
        // Appends a "multibroadcast" of `scalar` to the lens of `like`.
        auto broadcast_like = [&](auto scalar, auto like) {
            return m2.add_instruction(
                migraphx::make_op("multibroadcast", {{"out_lens", like->get_shape().lens()}}),
                scalar);
        };

        auto a = m2.add_parameter("x", square_shape);
        auto b = m2.add_parameter("y", square_shape);
        auto c = m2.add_parameter("z", column_shape);

        // The float alpha literal is broadcast first, then converted to double, then applied to x.
        auto alpha       = m2.add_literal(3.0f);
        auto alpha_bcast = broadcast_like(alpha, a);
        auto alpha_f64   = to_double(alpha_bcast);
        auto scaled_a    = m2.add_instruction(migraphx::make_op("mul"), alpha_f64, a);
        auto product     = m2.add_instruction(migraphx::make_op("dot"), scaled_a, b);

        // z is multibroadcast from {2, 1} to {2, 2} before beta is applied.
        auto c_bcast =
            m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), c);
        auto beta       = m2.add_literal(2.0f);
        auto beta_bcast = broadcast_like(beta, c_bcast);
        auto beta_f64   = to_double(beta_bcast);
        auto scaled_c   = m2.add_instruction(migraphx::make_op("mul"), c_bcast, beta_f64);

        auto sum = m2.add_instruction(migraphx::make_op("add"), product, scaled_c);
        m2.add_instruction(migraphx::make_op("identity"), sum);
    }

    EXPECT(m1 == m2);
}

// Builds a module with insert_apply_alpha_beta on "quant_dot" (int8 x and y, int32 z) using
// int32 literals alpha = 3 and beta = 2, and expects it to compare equal to a hand-written
// reference module.
TEST_CASE(quant_dot_apply_alpha_beta)
{
    const auto int8_t_id  = migraphx::shape::int8_type;
    const auto int32_t_id = migraphx::shape::int32_type;
    const migraphx::shape input_shape{int8_t_id, {2, 2}};
    const migraphx::shape accum_shape{int32_t_id, {2, 2}};

    // Module generated by the helper under test.
    migraphx::module m1;
    {
        auto a = m1.add_parameter("x", input_shape);
        auto b = m1.add_parameter("y", input_shape);
        auto c = m1.add_parameter("z", accum_shape);

        const migraphx::literal alpha{int32_t{3}};
        const migraphx::literal beta{int32_t{2}};
        auto result = migraphx::insert_apply_alpha_beta(
            m1, m1.end(), {a, b, c}, migraphx::make_op("quant_dot"), alpha, beta);
        m1.add_instruction(migraphx::make_op("identity"), result);
    }

    // Reference module, written out instruction by instruction. The order in which
    // instructions are added is kept deliberately, since the modules are compared as a whole.
    migraphx::module m2;
    {
        // Appends a "convert" of `ins` to element type `type`.
        auto convert_to = [&](auto type, auto ins) {
            return m2.add_instruction(migraphx::make_op("convert", {{"target_type", type}}), ins);
        };
        // Appends a "multibroadcast" of `scalar` to the lens of `like`.
        auto broadcast_like = [&](auto scalar, auto like) {
            return m2.add_instruction(
                migraphx::make_op("multibroadcast", {{"out_lens", like->get_shape().lens()}}),
                scalar);
        };

        auto a = m2.add_parameter("x", input_shape);
        auto b = m2.add_parameter("y", input_shape);
        auto c = m2.add_parameter("z", accum_shape);

        // x is converted to int32, scaled by alpha, then converted back to int8 for quant_dot.
        auto alpha        = m2.add_literal(int32_t(3));
        auto alpha_bcast  = broadcast_like(alpha, a);
        auto a_i32        = convert_to(int32_t_id, a);
        auto scaled_a_i32 = m2.add_instruction(migraphx::make_op("mul"), alpha_bcast, a_i32);
        auto scaled_a     = convert_to(int8_t_id, scaled_a_i32);
        auto product      = m2.add_instruction(migraphx::make_op("quant_dot"), scaled_a, b);

        // z is already int32, so beta is multiplied in without any convert.
        auto beta       = m2.add_literal(int32_t(2));
        auto beta_bcast = broadcast_like(beta, c);
        auto scaled_c   = m2.add_instruction(migraphx::make_op("mul"), c, beta_bcast);

        auto sum = m2.add_instruction(migraphx::make_op("add"), product, scaled_c);
        m2.add_instruction(migraphx::make_op("identity"), sum);
    }

    EXPECT(m1 == m2);
}

// Program entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}