propagate_constant_test.cpp 3.47 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/propagate_constant.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/op/add.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/op/scalar.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/op/mul.hpp>
Paul's avatar
Paul committed
6
7
8
9
10
11
#include <basic_ops.hpp>
#include <test.hpp>

struct const_prop_target
{
    std::string name() const { return "const_prop"; }
Paul's avatar
Paul committed
12
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
Paul's avatar
Paul committed
13
    {
Paul's avatar
Paul committed
14
        return {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}};
Paul's avatar
Paul committed
15
    }
Paul's avatar
Paul committed
16
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
17
18
};

Paul's avatar
Paul committed
19
// Adding two literals must fold into a single literal holding their sum.
TEST_CASE(const_add)
{
    migraphx::program p1;
    auto one = p1.add_literal(1);
    auto two = p1.add_literal(2);
    auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
    p1.add_instruction(pass_op{}, sum);
    p1.compile(const_prop_target{});

    // Expected result: 1 + 2 replaced by the literal 3 feeding pass_op.
    migraphx::program p2;
    auto total = p2.add_literal(3);
    p2.add_instruction(pass_op{}, total);
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
34
// An add whose input is a runtime parameter is not constant, so it must
// NOT be folded into a literal.
TEST_CASE(const_add_parameter)
{
    migraphx::program p1;
    auto one = p1.add_parameter("one", {migraphx::shape::int32_type, {1}});
    auto two = p1.add_literal(2);
    auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
    p1.add_instruction(pass_op{}, sum);
    p1.compile(const_prop_target{});

    // A fully folded program would look like this; p1 must differ from it.
    migraphx::program p2;
    auto total = p2.add_literal(3);
    p2.add_instruction(pass_op{}, total);
    EXPECT(p1 != p2);
}

Paul's avatar
Paul committed
49
// A chain of adds over literals must fold all the way down to one literal.
TEST_CASE(const_multiadd)
{
    migraphx::program p1;
    auto one  = p1.add_literal(1);
    auto two  = p1.add_literal(2);
    auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
    auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
    p1.add_instruction(pass_op{}, sum2);
    p1.compile(const_prop_target{});

    // Expected result: (1 + 2) + 2 == 5.
    migraphx::program p2;
    auto total = p2.add_literal(5);
    p2.add_instruction(pass_op{}, total);
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
65
// Mixed mul/add expressions over literals must fold to a single literal.
TEST_CASE(const_add_mul)
{
    migraphx::program p1;
    auto one  = p1.add_literal(1);
    auto two  = p1.add_literal(2);
    auto mul  = p1.add_instruction(migraphx::op::mul{}, two, two);
    auto sum1 = p1.add_instruction(migraphx::op::add{}, one, mul);
    auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
    p1.add_instruction(pass_op{}, sum2);
    p1.compile(const_prop_target{});

    // Expected result: (1 + 2 * 2) + 2 == 7.
    migraphx::program p2;
    auto total = p2.add_literal(7);
    p2.add_instruction(pass_op{}, total);
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// Two literals broadcast to a 2x2 int32 shape and then added should fold
// into a single 2x2 literal filled with their sum.
TEST_CASE(const_add_scalar)
{
    // Broadcast op shared by both operands.
    const migraphx::op::scalar bcast{{migraphx::shape::int32_type, {2, 2}}};

    migraphx::program p1;
    auto lhs_lit = p1.add_literal(1);
    auto lhs     = p1.add_instruction(bcast, lhs_lit);
    auto rhs_lit = p1.add_literal(2);
    auto rhs     = p1.add_instruction(bcast, rhs_lit);
    auto added   = p1.add_instruction(migraphx::op::add{}, lhs, rhs);
    p1.add_instruction(pass_op{}, added);
    p1.compile(const_prop_target{});

    // Expected result: every element of the 2x2 tensor equals 1 + 2.
    const migraphx::literal expected{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}};
    migraphx::program p2;
    auto folded = p2.add_literal(expected);
    p2.add_instruction(pass_op{}, folded);
    EXPECT(p1 == p2);
}

// A lone scalar broadcast of a literal (no arithmetic on top) is expected
// to come out of the const_prop pipeline unchanged.
TEST_CASE(const_scalar)
{
    // Appends literal(1) -> scalar broadcast to 2x2 int32 -> pass_op.
    const auto append_broadcast_one = [](migraphx::program& prog) {
        auto lit   = prog.add_literal(1);
        auto bcast = prog.add_instruction(
            migraphx::op::scalar{{migraphx::shape::int32_type, {2, 2}}}, lit);
        prog.add_instruction(pass_op{}, bcast);
    };

    migraphx::program p1;
    append_broadcast_one(p1);
    p1.compile(const_prop_target{});

    // Reference program: the same graph, never compiled.
    migraphx::program p2;
    append_broadcast_one(p2);
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
114
int main(int argc, const char* argv[])
{
    // Forward the command line to the test driver from test.hpp, which
    // presumably runs the TEST_CASEs registered in this file.
    test::run(argc, argv);
}