eliminate_allocation_test.cpp 3.99 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
4
5
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
Paul's avatar
Paul committed
6
7
8
#include <basic_ops.hpp>
#include <test.hpp>

9
// Runs eliminate_allocation (configured to treat "allocate" instructions as the
// allocations to eliminate, using `align` as its alignment argument) and then
// dead_code_elimination over the main module of `p`.
// The tests in this file then inspect the resulting "memory" parameter.
void run_pass(migraphx::program& p, std::size_t align = 32)
{
    auto& main_module = *p.get_main_module();
    migraphx::run_passes(main_module,
                         {migraphx::eliminate_allocation{"allocate", align},
                          migraphx::dead_code_elimination{}});
}
Paul's avatar
Paul committed
15
16
17

struct allocate
{
Paul's avatar
Paul committed
18
    migraphx::shape s{};
19
20
21
22
23
24

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.s, "shape"));
    }
Paul's avatar
Paul committed
25

Paul's avatar
Paul committed
26
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
27
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
Paul's avatar
Paul committed
28
    {
29
        migraphx::check_shapes{inputs, *this}.has(0);
Paul's avatar
Paul committed
30
31
        return s;
    }
Paul's avatar
Paul committed
32
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
33
34
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
Paul's avatar
Paul committed
35
36
37
38
39
    {
        return {output_shape};
    }
};

Paul's avatar
Paul committed
40
TEST_CASE(basic)
Paul's avatar
Paul committed
41
{
Paul's avatar
Paul committed
42
    migraphx::program p;
Paul's avatar
Paul committed
43

44
45
46
    auto* mm = p.get_main_module();
    auto a1  = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {8}}});
    auto p1  = mm->add_instruction(pass_op{}, a1);
Paul's avatar
Paul committed
47

48
49
50
51
52
    auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {40}}});
    auto p2 = mm->add_instruction(pass_op{}, a2, p1);

    auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
    mm->add_instruction(pass_op{}, a3, p2);
Paul's avatar
Paul committed
53

54
    run_pass(p);
55
    EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
Paul's avatar
Paul committed
56
    EXPECT(p.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4));
Paul's avatar
Paul committed
57
58
}

Paul's avatar
Paul committed
59
TEST_CASE(aligned)
Paul's avatar
Paul committed
60
{
Paul's avatar
Paul committed
61
    migraphx::program p;
Paul's avatar
Paul committed
62

63
64
65
66
67
68
    auto* mm = p.get_main_module();
    auto a1  = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
    auto p1  = mm->add_instruction(pass_op{}, a1);

    auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
    auto p2 = mm->add_instruction(pass_op{}, a2, p1);
Paul's avatar
Paul committed
69

70
71
    auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
    mm->add_instruction(pass_op{}, a3, p2);
Paul's avatar
Paul committed
72

73
    run_pass(p);
74
    EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
Paul's avatar
Paul committed
75
    EXPECT(p.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4));
Paul's avatar
Paul committed
76
77
}

Paul's avatar
Paul committed
78
TEST_CASE(unaligned)
Paul's avatar
Paul committed
79
{
Paul's avatar
Paul committed
80
    migraphx::program p;
Paul's avatar
Paul committed
81

82
83
84
    auto* mm = p.get_main_module();
    auto a1  = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
    auto p1  = mm->add_instruction(pass_op{}, a1);
Paul's avatar
Paul committed
85

86
87
88
89
90
    auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
    auto p2 = mm->add_instruction(pass_op{}, a2, p1);

    auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
    mm->add_instruction(pass_op{}, a3, p2);
Paul's avatar
Paul committed
91

92
    run_pass(p, 1);
93
    EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
Paul's avatar
Paul committed
94
    EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4));
Paul's avatar
Paul committed
95
96
}

Paul's avatar
Paul committed
97
TEST_CASE(float_aligned)
Paul's avatar
Paul committed
98
{
Paul's avatar
Paul committed
99
    migraphx::program p;
Paul's avatar
Paul committed
100

101
102
103
104
105
106
    auto* mm = p.get_main_module();
    auto a1  = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
    auto p1  = mm->add_instruction(pass_op{}, a1);

    auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
    auto p2 = mm->add_instruction(pass_op{}, a2, p1);
Paul's avatar
Paul committed
107

108
109
    auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
    mm->add_instruction(pass_op{}, a3, p2);
Paul's avatar
Paul committed
110

111
    run_pass(p, 4);
112
    EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
Paul's avatar
Paul committed
113
    EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4));
Paul's avatar
Paul committed
114
115
}

116
// Test executable entry point: hands the command line to the test framework.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}