// alexnet.cpp - builds the AlexNet model as a MIGraphX program for the driver.
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <cstddef>
#include <vector>
#include "models.hpp"

namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

// Builds the AlexNet graph inside the main module of a fresh migraphx::program.
//
// The module gets one float parameter named "0" with lens {batch, 3, 224, 224},
// sixteen generated float literals (generator arguments 0..15) used as weights
// and biases, and then the layer instructions in this order:
//   conv(11x11, stride 4, pad 2) + bias + relu, max-pool
//   conv(5x5,  stride 1, pad 2) + bias + relu, max-pool
//   conv(3x3,  stride 1, pad 1) + bias + relu   (three times), max-pool
//   flatten, identity, dot(4096) + relu, identity, dot(4096) + relu, dot(1000)
//
// The order in which literals and instructions are added is kept fixed on
// purpose, since it defines the layout of the resulting module.
//
// batch: leading dimension of the input parameter and of every broadcast bias.
// Returns the populated program; the last dot is the final instruction added.
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    // Adds one generated float literal; `seed` is forwarded as the second
    // argument of generate_literal so every tensor gets its own value.
    const auto add_weights = [&](const std::vector<std::size_t>& lens, unsigned long seed) {
        return mm->add_literal(migraphx::generate_literal(
            migraphx::shape{migraphx::shape::float_type, lens}, seed));
    };

    // convolution (dilation 1, group 1) -> bias broadcast along axis 1 -> add -> relu.
    // `channels` and `side` describe the convolution output ({batch, channels, side, side}),
    // which is the shape the bias has to be broadcast to.
    const auto conv_bias_relu = [&](auto input,
                                    auto weight,
                                    auto bias,
                                    std::size_t pad,
                                    std::size_t stride,
                                    std::size_t channels,
                                    std::size_t side) {
        migraphx::op::convolution conv;
        conv.padding  = {pad, pad};
        conv.stride   = {stride, stride};
        conv.dilation = {1, 1};
        conv.group    = 1;
        auto conv_out = mm->add_instruction(conv, input, weight);

        migraphx::op::broadcast bias_bcast;
        bias_bcast.axis           = 1;
        bias_bcast.broadcast_lens = {batch, channels, side, side};
        auto bias_out             = mm->add_instruction(bias_bcast, bias);

        auto biased = mm->add_instruction(migraphx::op::add{}, conv_out, bias_out);
        return mm->add_instruction(migraphx::op::relu{}, biased);
    };

    // 3x3 max pooling with stride 2 and no padding.
    const auto max_pool = [&](auto input) {
        migraphx::op::pooling pool;
        pool.mode    = migraphx::op::pooling_mode::max;
        pool.padding = {0, 0};
        pool.stride  = {2, 2};
        pool.lengths = {3, 3};
        return mm->add_instruction(pool, input);
    };

    // transpose(weight) -> multibroadcast(bias) to {batch, out_features}
    // -> "dot" applied through add_apply_alpha_beta with alpha = beta = 1.
    const auto fully_connected =
        [&](auto input, auto weight, auto bias, std::size_t out_features) {
            migraphx::op::transpose weight_transpose;
            weight_transpose.dims = {1, 0};
            auto weight_t         = mm->add_instruction(weight_transpose, weight);

            migraphx::op::multibroadcast bias_bcast;
            bias_bcast.output_lens = {batch, out_features};
            auto bias_out          = mm->add_instruction(bias_bcast, bias);

            float alpha = 1;
            float beta  = 1;
            return migraphx::add_apply_alpha_beta(
                *mm, {input, weight_t, bias_out}, migraphx::make_op("dot"), alpha, beta);
        };

    // Input parameter first, then the literals in generator-argument order 0..15.
    auto input =
        mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});

    auto fc3_bias     = add_weights({1000}, 0);
    auto fc3_weight   = add_weights({1000, 4096}, 1);
    auto fc2_bias     = add_weights({4096}, 2);
    auto fc2_weight   = add_weights({4096, 4096}, 3);
    auto fc1_bias     = add_weights({4096}, 4);
    auto fc1_weight   = add_weights({4096, 9216}, 5);
    auto conv5_bias   = add_weights({256}, 6);
    auto conv5_weight = add_weights({256, 256, 3, 3}, 7);
    auto conv4_bias   = add_weights({256}, 8);
    auto conv4_weight = add_weights({256, 384, 3, 3}, 9);
    auto conv3_bias   = add_weights({384}, 10);
    auto conv3_weight = add_weights({384, 192, 3, 3}, 11);
    auto conv2_bias   = add_weights({192}, 12);
    auto conv2_weight = add_weights({192, 64, 5, 5}, 13);
    auto conv1_bias   = add_weights({64}, 14);
    auto conv1_weight = add_weights({64, 3, 11, 11}, 15);

    // Feature extractor.
    auto conv1 = conv_bias_relu(input, conv1_weight, conv1_bias, 2, 4, 64, 55);
    auto pool1 = max_pool(conv1);
    auto conv2 = conv_bias_relu(pool1, conv2_weight, conv2_bias, 2, 1, 192, 27);
    auto pool2 = max_pool(conv2);
    auto conv3 = conv_bias_relu(pool2, conv3_weight, conv3_bias, 1, 1, 384, 13);
    auto conv4 = conv_bias_relu(conv3, conv4_weight, conv4_bias, 1, 1, 256, 13);
    auto conv5 = conv_bias_relu(conv4, conv5_weight, conv5_bias, 1, 1, 256, 13);
    auto pool5 = max_pool(conv5);

    // Classifier.
    migraphx::op::flatten flatten;
    flatten.axis   = 1;
    auto flat      = mm->add_instruction(flatten, pool5);
    auto fc1_input = mm->add_instruction(migraphx::op::identity{}, flat);

    auto fc1       = fully_connected(fc1_input, fc1_weight, fc1_bias, 4096);
    auto fc1_relu  = mm->add_instruction(migraphx::op::relu{}, fc1);
    auto fc2_input = mm->add_instruction(migraphx::op::identity{}, fc1_relu);

    auto fc2      = fully_connected(fc2_input, fc2_weight, fc2_bias, 4096);
    auto fc2_relu = mm->add_instruction(migraphx::op::relu{}, fc2);

    // The final dot is not bound to a name; it is simply the last instruction added.
    fully_connected(fc2_relu, fc3_weight, fc3_bias, 1000);

    return p;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx