auto_contiguous_test.cpp 3.84 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/instruction.hpp>
3
#include <migraphx/pass_manager.hpp>
Paul's avatar
Paul committed
4
#include <basic_ops.hpp>
5
6
#include <migraphx/make_op.hpp>

Paul's avatar
Paul committed
7
8
#include <test.hpp>

9
10
11
12
void run_pass(migraphx::program& p)
{
    migraphx::run_passes(*p.get_main_module(), {migraphx::auto_contiguous{}});
}
Paul's avatar
Paul committed
13

Paul's avatar
Paul committed
14
// TODO: Add this test case
Paul's avatar
Paul committed
15
16
void literal_broadcast()
{
Paul's avatar
Paul committed
17
    migraphx::program p;
18
19
20

    auto* mm = p.get_main_module();
    mm->add_literal(get_2_broadcasted());
21
22
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().broadcasted());
23
    run_pass(p);
24
25
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().broadcasted());
Paul's avatar
Paul committed
26
27
}

Paul's avatar
Paul committed
28
TEST_CASE(literal_transpose)
Paul's avatar
Paul committed
29
{
Paul's avatar
Paul committed
30
    migraphx::program p;
31
32
33

    auto* mm = p.get_main_module();
    mm->add_literal(get_2x2_transposed());
34
35
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
36
    run_pass(p);
37
38
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Paul's avatar
Paul committed
39
40
}

Paul's avatar
Paul committed
41
TEST_CASE(after_literal_transpose)
Paul's avatar
Paul committed
42
{
Paul's avatar
Paul committed
43
    migraphx::program p;
44
45
46

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
47
48
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
49
    auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
50
    mm->add_instruction(pass_op{}, t);
51
52
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
53
    run_pass(p);
54
55
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Paul's avatar
Paul committed
56
57
}

Paul's avatar
Paul committed
58
TEST_CASE(after_literal_broadcast)
Paul's avatar
Paul committed
59
{
Paul's avatar
Paul committed
60
    migraphx::program p;
61
62
63
64

    auto* mm = p.get_main_module();
    auto l1  = mm->add_literal(get_2x2());
    auto l2  = mm->add_literal(get_2());
65
66
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().broadcasted());
67
68
    auto b = mm->add_instruction(
        migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2);
69
    mm->add_instruction(pass_op{}, b);
70
71
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().broadcasted());
72
    run_pass(p);
73
74
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().broadcasted());
Paul's avatar
Paul committed
75
76
}

Paul's avatar
Paul committed
77
TEST_CASE(after_param_transpose)
Paul's avatar
Paul committed
78
{
Paul's avatar
Paul committed
79
    migraphx::program p;
80
81
82

    auto* mm = p.get_main_module();
    auto l   = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
83
84
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
85
    auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
86
    mm->add_instruction(pass_op{}, t);
87
88
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
89
    run_pass(p);
90
91
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Paul's avatar
Paul committed
92
93
}

Paul's avatar
Paul committed
94
TEST_CASE(after_param_broadcast)
Paul's avatar
Paul committed
95
{
Paul's avatar
Paul committed
96
    migraphx::program p;
97
98
99
100

    auto* mm = p.get_main_module();
    auto l1  = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
    auto l2  = mm->add_parameter("2", {migraphx::shape::float_type, {2}});
101
102
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().broadcasted());
103
104
    auto b = mm->add_instruction(
        migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2);
105
    mm->add_instruction(pass_op{}, b);
106
107
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().broadcasted());
108
    run_pass(p);
109
110
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().broadcasted());
Paul's avatar
Paul committed
111
112
}

Paul's avatar
Paul committed
113
int main(int argc, const char* argv[]) { test::run(argc, argv); }