auto_contiguous_test.cpp 3.59 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/instruction.hpp>
3
#include <migraphx/pass_manager.hpp>
Paul's avatar
Paul committed
4
#include <basic_ops.hpp>
5
6
#include <migraphx/make_op.hpp>

Paul's avatar
Paul committed
7
8
#include <test.hpp>

Paul Fultz II's avatar
Paul Fultz II committed
9
// Runs only the auto_contiguous pass over the given module, modifying it in place.
void run_pass(migraphx::module& m)
{
    migraphx::run_passes(m, {migraphx::auto_contiguous{}});
}
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
// TODO: Add this test case
Paul's avatar
Paul committed
12
13
void literal_broadcast()
{
Paul Fultz II's avatar
Paul Fultz II committed
14
    migraphx::module m;
15

Paul Fultz II's avatar
Paul Fultz II committed
16
17
18
19
20
21
    m.add_literal(get_2_broadcasted());
    EXPECT(not m.get_output_shapes().back().standard());
    EXPECT(m.get_output_shapes().back().broadcasted());
    run_pass(m);
    EXPECT(m.get_output_shapes().back().standard());
    EXPECT(not m.get_output_shapes().back().broadcasted());
Paul's avatar
Paul committed
22
23
}

Paul's avatar
Paul committed
24
TEST_CASE(literal_transpose)
{
    migraphx::module mod;
    // Last entry of the module's output shapes, re-queried on every call.
    auto last_shape = [&] { return mod.get_output_shapes().back(); };

    // A transposed literal is the module's only (and therefore last) instruction.
    mod.add_literal(get_2x2_transposed());
    EXPECT(not last_shape().standard());
    EXPECT(last_shape().transposed());

    // After the pass, the output must be in standard (non-transposed) layout.
    run_pass(mod);
    EXPECT(last_shape().standard());
    EXPECT(not last_shape().transposed());
}

Paul's avatar
Paul committed
36
TEST_CASE(after_literal_transpose)
{
    migraphx::module mod;
    // Last entry of the module's output shapes, re-queried on every call.
    auto last_shape = [&] { return mod.get_output_shapes().back(); };

    auto input = mod.add_literal(get_2x2());
    EXPECT(last_shape().standard());
    EXPECT(not last_shape().transposed());

    // Swap the two axes of the literal, then feed the transposed value to pass_op.
    const auto swap_axes = migraphx::make_op("transpose", {{"permutation", {1, 0}}});
    auto transposed      = mod.add_instruction(swap_axes, input);
    mod.add_instruction(pass_op{}, transposed);
    EXPECT(not last_shape().standard());
    EXPECT(last_shape().transposed());

    // The pass must leave the final output in standard layout.
    run_pass(mod);
    EXPECT(last_shape().standard());
    EXPECT(not last_shape().transposed());
}

Paul's avatar
Paul committed
52
TEST_CASE(after_literal_broadcast)
{
    migraphx::module mod;
    // Last entry of the module's output shapes, re-queried on every call.
    auto last_shape = [&] { return mod.get_output_shapes().back(); };

    auto shape_source = mod.add_literal(get_2x2());
    auto to_broadcast = mod.add_literal(get_2());
    EXPECT(last_shape().standard());
    EXPECT(not last_shape().broadcasted());

    // Broadcast the second literal along axis 0 up to the first literal's lens.
    const auto target_lens = shape_source->get_shape().lens();
    auto expanded          = mod.add_instruction(
        migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", target_lens}}), to_broadcast);
    mod.add_instruction(pass_op{}, expanded);
    EXPECT(not last_shape().standard());
    EXPECT(last_shape().broadcasted());

    // The pass must leave the final output in standard (non-broadcast) layout.
    run_pass(mod);
    EXPECT(last_shape().standard());
    EXPECT(not last_shape().broadcasted());
}

Paul's avatar
Paul committed
70
TEST_CASE(after_param_transpose)
{
    migraphx::module mod;
    // Last entry of the module's output shapes, re-queried on every call.
    auto last_shape = [&] { return mod.get_output_shapes().back(); };

    const migraphx::shape param_shape{migraphx::shape::float_type, {2, 2}};
    auto input = mod.add_parameter("2x2", param_shape);
    EXPECT(last_shape().standard());
    EXPECT(not last_shape().transposed());

    // Swap the two axes of the parameter, then feed the transposed value to pass_op.
    const auto swap_axes = migraphx::make_op("transpose", {{"permutation", {1, 0}}});
    auto transposed      = mod.add_instruction(swap_axes, input);
    mod.add_instruction(pass_op{}, transposed);
    EXPECT(not last_shape().standard());
    EXPECT(last_shape().transposed());

    // The pass must leave the final output in standard layout.
    run_pass(mod);
    EXPECT(last_shape().standard());
    EXPECT(not last_shape().transposed());
}

Paul's avatar
Paul committed
86
TEST_CASE(after_param_broadcast)
{
    migraphx::module mod;
    // Last entry of the module's output shapes, re-queried on every call.
    auto last_shape = [&] { return mod.get_output_shapes().back(); };

    auto shape_source = mod.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
    auto to_broadcast = mod.add_parameter("2", {migraphx::shape::float_type, {2}});
    EXPECT(last_shape().standard());
    EXPECT(not last_shape().broadcasted());

    // Broadcast the {2} parameter along axis 0 up to the {2, 2} parameter's lens.
    const auto target_lens = shape_source->get_shape().lens();
    auto expanded          = mod.add_instruction(
        migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", target_lens}}), to_broadcast);
    mod.add_instruction(pass_op{}, expanded);
    EXPECT(not last_shape().standard());
    EXPECT(last_shape().broadcasted());

    // The pass must leave the final output in standard (non-broadcast) layout.
    run_pass(mod);
    EXPECT(last_shape().standard());
    EXPECT(not last_shape().broadcasted());
}

Paul's avatar
Paul committed
104
// Entry point: forwards the command-line arguments to the test framework's runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}