// alexnet.cpp — AlexNet model builder used by the MIGraphX driver.
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include "models.hpp"

namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

/// @brief Build an AlexNet-style inference graph as a MIGraphX program.
///
/// The graph has one float input parameter named "0" of shape
/// {batch, 3, 224, 224}. Every weight and bias is a literal created with
/// migraphx::generate_literal using a fixed seed (0..15). The values are
/// synthetic, not trained, and calls with the same `batch` build the same
/// graph.
///
/// Stage layout (spatial sizes follow from the padding/stride set below):
///   conv 11x11 s4 p2 ( 64) + bias -> relu -> maxpool 3x3 s2 : 55x55 -> 27x27
///   conv  5x5  s1 p2 (192) + bias -> relu -> maxpool 3x3 s2 : 27x27 -> 13x13
///   conv  3x3  s1 p1 (384) + bias -> relu                   : 13x13
///   conv  3x3  s1 p1 (256) + bias -> relu                   : 13x13
///   conv  3x3  s1 p1 (256) + bias -> relu -> maxpool 3x3 s2 : 13x13 -> 6x6
///   flatten (256*6*6 = 9216) -> fc 4096 -> relu -> fc 4096 -> relu -> fc 1000
///
/// @param batch  Leading (batch) dimension of the input and of every activation.
/// @return The program. Its final instruction yields a {batch, 1000} tensor.
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto m0 =
        mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});

    // Parameters are created last-layer-first: seed 0 is the final fc bias and
    // seed 15 is the first conv kernel. The seeds determine the generated
    // values, so keep both the order and the seeds stable.
    auto mx0 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0));
    auto mx1 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1));
    auto mx2 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2));
    auto mx3 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
    auto mx4 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
    auto mx5 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
    auto mx6 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6));
    auto mx7 = mm->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7));
    auto mx8 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8));
    auto mx9  = mm->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
    auto mx10 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10));
    auto mx11 = mm->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
    auto mx12 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12));
    auto mx13 = mm->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
    auto mx14 = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14));
    auto mx15 = mm->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15));

    // ---- Stage 1: conv 11x11/4, 3 -> 64 channels, 224x224 -> 55x55 ----
    migraphx::op::convolution convolution16;
    convolution16.padding  = {2, 2};
    convolution16.stride   = {4, 4};
    convolution16.dilation = {1, 1};
    convolution16.group    = 1;
    auto mx16              = mm->add_instruction(convolution16, m0, mx15);
    // The per-channel bias is broadcast along axis 1 to the conv output shape
    // before the elementwise add.
    migraphx::op::broadcast broadcast17;
    broadcast17.axis           = 1;
    broadcast17.broadcast_lens = {batch, 64, 55, 55};
    auto mx17                  = mm->add_instruction(broadcast17, mx14);
    migraphx::op::add add18;
    auto mx18 = mm->add_instruction(add18, mx16, mx17);
    migraphx::op::relu relu19;
    auto mx19 = mm->add_instruction(relu19, mx18);
    // Max-pool 3x3/2: 55x55 -> 27x27.
    migraphx::op::pooling pooling20;
    pooling20.mode    = "max";
    pooling20.padding = {0, 0};
    pooling20.stride  = {2, 2};
    pooling20.lengths = {3, 3};
    auto mx20         = mm->add_instruction(pooling20, mx19);

    // ---- Stage 2: conv 5x5/1 pad 2, 64 -> 192 channels, 27x27 ----
    migraphx::op::convolution convolution21;
    convolution21.padding  = {2, 2};
    convolution21.stride   = {1, 1};
    convolution21.dilation = {1, 1};
    convolution21.group    = 1;
    auto mx21              = mm->add_instruction(convolution21, mx20, mx13);
    migraphx::op::broadcast broadcast22;
    broadcast22.axis           = 1;
    broadcast22.broadcast_lens = {batch, 192, 27, 27};
    auto mx22                  = mm->add_instruction(broadcast22, mx12);
    migraphx::op::add add23;
    auto mx23 = mm->add_instruction(add23, mx21, mx22);
    migraphx::op::relu relu24;
    auto mx24 = mm->add_instruction(relu24, mx23);
    // Max-pool 3x3/2: 27x27 -> 13x13.
    migraphx::op::pooling pooling25;
    pooling25.mode    = "max";
    pooling25.padding = {0, 0};
    pooling25.stride  = {2, 2};
    pooling25.lengths = {3, 3};
    auto mx25         = mm->add_instruction(pooling25, mx24);

    // ---- Stage 3: conv 3x3/1 pad 1, 192 -> 384 channels, 13x13 ----
    migraphx::op::convolution convolution26;
    convolution26.padding  = {1, 1};
    convolution26.stride   = {1, 1};
    convolution26.dilation = {1, 1};
    convolution26.group    = 1;
    auto mx26              = mm->add_instruction(convolution26, mx25, mx11);
    migraphx::op::broadcast broadcast27;
    broadcast27.axis           = 1;
    broadcast27.broadcast_lens = {batch, 384, 13, 13};
    auto mx27                  = mm->add_instruction(broadcast27, mx10);
    migraphx::op::add add28;
    auto mx28 = mm->add_instruction(add28, mx26, mx27);
    migraphx::op::relu relu29;
    auto mx29 = mm->add_instruction(relu29, mx28);

    // ---- Stage 4: conv 3x3/1 pad 1, 384 -> 256 channels, 13x13 ----
    migraphx::op::convolution convolution30;
    convolution30.padding  = {1, 1};
    convolution30.stride   = {1, 1};
    convolution30.dilation = {1, 1};
    convolution30.group    = 1;
    auto mx30              = mm->add_instruction(convolution30, mx29, mx9);
    migraphx::op::broadcast broadcast31;
    broadcast31.axis           = 1;
    broadcast31.broadcast_lens = {batch, 256, 13, 13};
    auto mx31                  = mm->add_instruction(broadcast31, mx8);
    migraphx::op::add add32;
    auto mx32 = mm->add_instruction(add32, mx30, mx31);
    migraphx::op::relu relu33;
    auto mx33 = mm->add_instruction(relu33, mx32);

    // ---- Stage 5: conv 3x3/1 pad 1, 256 -> 256 channels, 13x13 ----
    migraphx::op::convolution convolution34;
    convolution34.padding  = {1, 1};
    convolution34.stride   = {1, 1};
    convolution34.dilation = {1, 1};
    convolution34.group    = 1;
    auto mx34              = mm->add_instruction(convolution34, mx33, mx7);
    migraphx::op::broadcast broadcast35;
    broadcast35.axis           = 1;
    broadcast35.broadcast_lens = {batch, 256, 13, 13};
    auto mx35                  = mm->add_instruction(broadcast35, mx6);
    migraphx::op::add add36;
    auto mx36 = mm->add_instruction(add36, mx34, mx35);
    migraphx::op::relu relu37;
    auto mx37 = mm->add_instruction(relu37, mx36);
    // Max-pool 3x3/2: 13x13 -> 6x6.
    migraphx::op::pooling pooling38;
    pooling38.mode    = "max";
    pooling38.padding = {0, 0};
    pooling38.stride  = {2, 2};
    pooling38.lengths = {3, 3};
    auto mx38         = mm->add_instruction(pooling38, mx37);

    // ---- Classifier ----
    // Flatten {batch, 256, 6, 6} -> {batch, 9216}.
    migraphx::op::flatten flatten39;
    flatten39.axis = 1;
    auto mx39      = mm->add_instruction(flatten39, mx38);
    // NOTE(review): the identity ops presumably stand in for dropout layers of
    // the reference model, which are no-ops at inference. Confirm against the
    // model source.
    migraphx::op::identity identity40;
    auto mx40 = mm->add_instruction(identity40, mx39);

    // Each fully connected layer follows the same pattern:
    //   1. The weight is stored {out, in} and transposed to {in, out}.
    //   2. The bias is multibroadcast to {batch, out}.
    //   3. dot(x, W^T, bias) runs with alpha = 1 and beta = 1, so the bias is
    //      added to the product.
    // fc1: 9216 -> 4096.
    migraphx::op::transpose transpose41;
    transpose41.dims = {1, 0};
    auto mx41        = mm->add_instruction(transpose41, mx5);
    migraphx::op::multibroadcast multibroadcast42;
    multibroadcast42.output_lens = {batch, 4096};
    auto mx42                    = mm->add_instruction(multibroadcast42, mx4);
    migraphx::op::dot dot43;
    dot43.alpha = 1;
    dot43.beta  = 1;
    auto mx43   = mm->add_instruction(dot43, mx40, mx41, mx42);
    migraphx::op::relu relu44;
    auto mx44 = mm->add_instruction(relu44, mx43);
    migraphx::op::identity identity45;
    auto mx45 = mm->add_instruction(identity45, mx44);

    // fc2: 4096 -> 4096.
    migraphx::op::transpose transpose46;
    transpose46.dims = {1, 0};
    auto mx46        = mm->add_instruction(transpose46, mx3);
    migraphx::op::multibroadcast multibroadcast47;
    multibroadcast47.output_lens = {batch, 4096};
    auto mx47                    = mm->add_instruction(multibroadcast47, mx2);
    migraphx::op::dot dot48;
    dot48.alpha = 1;
    dot48.beta  = 1;
    auto mx48   = mm->add_instruction(dot48, mx45, mx46, mx47);
    migraphx::op::relu relu49;
    auto mx49 = mm->add_instruction(relu49, mx48);

    // fc3: 4096 -> 1000. This is the program's final instruction; no
    // activation is applied.
    migraphx::op::transpose transpose50;
    transpose50.dims = {1, 0};
    auto mx50        = mm->add_instruction(transpose50, mx1);
    migraphx::op::multibroadcast multibroadcast51;
    multibroadcast51.output_lens = {batch, 1000};
    auto mx51                    = mm->add_instruction(multibroadcast51, mx0);
    migraphx::op::dot dot52;
    dot52.alpha = 1;
    dot52.beta  = 1;
    mm->add_instruction(dot52, mx49, mx50, mx51);
    return p;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx