auto_contiguous_test.cpp 3.77 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/auto_contiguous.hpp>
2
3
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/broadcast.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/instruction.hpp>
5
#include <migraphx/pass_manager.hpp>
Paul's avatar
Paul committed
6
7
8
#include <basic_ops.hpp>
#include <test.hpp>

9
10
11
12
// Applies only the auto_contiguous pass to the main module of `p`, in place.
void run_pass(migraphx::program& p)
{
    auto& main_module = *p.get_main_module();
    migraphx::run_passes(main_module, {migraphx::auto_contiguous{}});
}
Paul's avatar
Paul committed
13

Paul's avatar
Paul committed
14
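// TODO: Enable literal_broadcast below by declaring it with TEST_CASE; it is currently a plain function that nothing in this file calls.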
Paul's avatar
Paul committed
15
16
void literal_broadcast()
{
Paul's avatar
Paul committed
17
    migraphx::program p;
18
19
20

    auto* mm = p.get_main_module();
    mm->add_literal(get_2_broadcasted());
21
22
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().broadcasted());
23
    run_pass(p);
24
25
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().broadcasted());
Paul's avatar
Paul committed
26
27
}

Paul's avatar
Paul committed
28
// A transposed literal as the program output should become standard after auto_contiguous.
TEST_CASE(literal_transpose)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    mm->add_literal(get_2x2_transposed());

    // Re-queries the program each call so it always reflects the latest last output.
    auto last_output = [&] { return p.get_output_shapes().back(); };

    // Before the pass: the output is transposed, hence not standard.
    EXPECT(not last_output().standard());
    EXPECT(last_output().transposed());

    run_pass(p);

    // After the pass: the output is standard and no longer transposed.
    EXPECT(last_output().standard());
    EXPECT(not last_output().transposed());
}

Paul's avatar
Paul committed
41
// A transpose of a standard literal, consumed by pass_op, should yield a standard
// program output once auto_contiguous has run.
TEST_CASE(after_literal_transpose)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    // Re-queries the program each call so it always reflects the latest last output.
    auto last_output = [&] { return p.get_output_shapes().back(); };

    auto input = mm->add_literal(get_2x2());
    // The freshly added literal starts out standard.
    EXPECT(last_output().standard());
    EXPECT(not last_output().transposed());

    auto swapped = mm->add_instruction(migraphx::op::transpose{{1, 0}}, input);
    mm->add_instruction(pass_op{}, swapped);
    // pass_op forwards the transposed layout, so the output is not standard yet.
    EXPECT(not last_output().standard());
    EXPECT(last_output().transposed());

    run_pass(p);

    EXPECT(last_output().standard());
    EXPECT(not last_output().transposed());
}

Paul's avatar
Paul committed
58
// A broadcast of a literal, consumed by pass_op, should yield a standard
// program output once auto_contiguous has run.
TEST_CASE(after_literal_broadcast)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    // Re-queries the program each call so it always reflects the latest last output.
    auto last_output = [&] { return p.get_output_shapes().back(); };

    auto target = mm->add_literal(get_2x2());
    auto source = mm->add_literal(get_2());
    EXPECT(last_output().standard());
    EXPECT(not last_output().broadcasted());

    // Broadcast `source` up to the lens of `target`.
    auto expanded =
        mm->add_instruction(migraphx::op::broadcast{0, target->get_shape().lens()}, source);
    mm->add_instruction(pass_op{}, expanded);
    EXPECT(not last_output().standard());
    EXPECT(last_output().broadcasted());

    run_pass(p);

    EXPECT(last_output().standard());
    EXPECT(not last_output().broadcasted());
}

Paul's avatar
Paul committed
76
// A transpose of a standard 2x2 float parameter, consumed by pass_op, should yield
// a standard program output once auto_contiguous has run.
TEST_CASE(after_param_transpose)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    // Re-queries the program each call so it always reflects the latest last output.
    auto last_output = [&] { return p.get_output_shapes().back(); };

    auto input = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
    EXPECT(last_output().standard());
    EXPECT(not last_output().transposed());

    auto swapped = mm->add_instruction(migraphx::op::transpose{{1, 0}}, input);
    mm->add_instruction(pass_op{}, swapped);
    EXPECT(not last_output().standard());
    EXPECT(last_output().transposed());

    run_pass(p);

    EXPECT(last_output().standard());
    EXPECT(not last_output().transposed());
}

Paul's avatar
Paul committed
93
// A broadcast of a {2} float parameter up to a {2, 2} parameter's lens, consumed by
// pass_op, should yield a standard program output once auto_contiguous has run.
TEST_CASE(after_param_broadcast)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    // Re-queries the program each call so it always reflects the latest last output.
    auto last_output = [&] { return p.get_output_shapes().back(); };

    auto target = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
    auto source = mm->add_parameter("2", {migraphx::shape::float_type, {2}});
    EXPECT(last_output().standard());
    EXPECT(not last_output().broadcasted());

    auto expanded =
        mm->add_instruction(migraphx::op::broadcast{0, target->get_shape().lens()}, source);
    mm->add_instruction(pass_op{}, expanded);
    EXPECT(not last_output().standard());
    EXPECT(last_output().broadcasted());

    run_pass(p);

    EXPECT(last_output().standard());
    EXPECT(not last_output().broadcasted());
}

Paul's avatar
Paul committed
111
int main(int argc, const char* argv[])
{
    // Forward the command line to the test framework; presumably it runs the
    // functions declared with TEST_CASE in this file.
    test::run(argc, argv);
}