// alexnet.cpp
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <cstddef>
#include <vector>

#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>

#include "models.hpp"

namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{
    migraphx::program p;
37
    auto* mm = p.get_main_module();
38
    auto m0 =
39
40
        mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
    auto mx0 = mm->add_literal(
41
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0));
42
    auto mx1 = mm->add_literal(
43
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1));
44
    auto mx2 = mm->add_literal(
45
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2));
46
    auto mx3 = mm->add_literal(
47
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
48
    auto mx4 = mm->add_literal(
49
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
50
    auto mx5 = mm->add_literal(
51
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
52
    auto mx6 = mm->add_literal(
53
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6));
54
    auto mx7 = mm->add_literal(migraphx::generate_literal(
55
        migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7));
56
    auto mx8 = mm->add_literal(
57
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8));
58
    auto mx9  = mm->add_literal(migraphx::generate_literal(
59
        migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
60
    auto mx10 = mm->add_literal(
61
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10));
62
    auto mx11 = mm->add_literal(migraphx::generate_literal(
63
        migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
64
    auto mx12 = mm->add_literal(
65
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12));
66
    auto mx13 = mm->add_literal(migraphx::generate_literal(
67
        migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
68
    auto mx14 = mm->add_literal(
69
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14));
70
    auto mx15 = mm->add_literal(migraphx::generate_literal(
71
72
73
74
75
76
        migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15));
    migraphx::op::convolution convolution16;
    convolution16.padding  = {2, 2};
    convolution16.stride   = {4, 4};
    convolution16.dilation = {1, 1};
    convolution16.group    = 1;
77
    auto mx16              = mm->add_instruction(convolution16, m0, mx15);
78
79
80
    migraphx::op::broadcast broadcast17;
    broadcast17.axis           = 1;
    broadcast17.broadcast_lens = {batch, 64, 55, 55};
81
    auto mx17                  = mm->add_instruction(broadcast17, mx14);
82
    migraphx::op::add add18;
83
    auto mx18 = mm->add_instruction(add18, mx16, mx17);
84
    migraphx::op::relu relu19;
85
    auto mx19 = mm->add_instruction(relu19, mx18);
86
    migraphx::op::pooling pooling20;
87
    pooling20.mode    = migraphx::op::pooling_mode::max;
88
89
90
    pooling20.padding = {0, 0};
    pooling20.stride  = {2, 2};
    pooling20.lengths = {3, 3};
91
    auto mx20         = mm->add_instruction(pooling20, mx19);
92
93
94
95
96
    migraphx::op::convolution convolution21;
    convolution21.padding  = {2, 2};
    convolution21.stride   = {1, 1};
    convolution21.dilation = {1, 1};
    convolution21.group    = 1;
97
    auto mx21              = mm->add_instruction(convolution21, mx20, mx13);
98
99
100
    migraphx::op::broadcast broadcast22;
    broadcast22.axis           = 1;
    broadcast22.broadcast_lens = {batch, 192, 27, 27};
101
    auto mx22                  = mm->add_instruction(broadcast22, mx12);
102
    migraphx::op::add add23;
103
    auto mx23 = mm->add_instruction(add23, mx21, mx22);
104
    migraphx::op::relu relu24;
105
    auto mx24 = mm->add_instruction(relu24, mx23);
106
    migraphx::op::pooling pooling25;
107
    pooling25.mode    = migraphx::op::pooling_mode::max;
108
109
110
    pooling25.padding = {0, 0};
    pooling25.stride  = {2, 2};
    pooling25.lengths = {3, 3};
111
    auto mx25         = mm->add_instruction(pooling25, mx24);
112
113
114
115
116
    migraphx::op::convolution convolution26;
    convolution26.padding  = {1, 1};
    convolution26.stride   = {1, 1};
    convolution26.dilation = {1, 1};
    convolution26.group    = 1;
117
    auto mx26              = mm->add_instruction(convolution26, mx25, mx11);
118
119
120
    migraphx::op::broadcast broadcast27;
    broadcast27.axis           = 1;
    broadcast27.broadcast_lens = {batch, 384, 13, 13};
121
    auto mx27                  = mm->add_instruction(broadcast27, mx10);
122
    migraphx::op::add add28;
123
    auto mx28 = mm->add_instruction(add28, mx26, mx27);
124
    migraphx::op::relu relu29;
125
    auto mx29 = mm->add_instruction(relu29, mx28);
126
127
128
129
130
    migraphx::op::convolution convolution30;
    convolution30.padding  = {1, 1};
    convolution30.stride   = {1, 1};
    convolution30.dilation = {1, 1};
    convolution30.group    = 1;
131
    auto mx30              = mm->add_instruction(convolution30, mx29, mx9);
132
133
134
    migraphx::op::broadcast broadcast31;
    broadcast31.axis           = 1;
    broadcast31.broadcast_lens = {batch, 256, 13, 13};
135
    auto mx31                  = mm->add_instruction(broadcast31, mx8);
136
    migraphx::op::add add32;
137
    auto mx32 = mm->add_instruction(add32, mx30, mx31);
138
    migraphx::op::relu relu33;
139
    auto mx33 = mm->add_instruction(relu33, mx32);
140
141
142
143
144
    migraphx::op::convolution convolution34;
    convolution34.padding  = {1, 1};
    convolution34.stride   = {1, 1};
    convolution34.dilation = {1, 1};
    convolution34.group    = 1;
145
    auto mx34              = mm->add_instruction(convolution34, mx33, mx7);
146
147
148
    migraphx::op::broadcast broadcast35;
    broadcast35.axis           = 1;
    broadcast35.broadcast_lens = {batch, 256, 13, 13};
149
    auto mx35                  = mm->add_instruction(broadcast35, mx6);
150
    migraphx::op::add add36;
151
    auto mx36 = mm->add_instruction(add36, mx34, mx35);
152
    migraphx::op::relu relu37;
153
    auto mx37 = mm->add_instruction(relu37, mx36);
154
    migraphx::op::pooling pooling38;
155
    pooling38.mode    = migraphx::op::pooling_mode::max;
156
157
158
    pooling38.padding = {0, 0};
    pooling38.stride  = {2, 2};
    pooling38.lengths = {3, 3};
159
    auto mx38         = mm->add_instruction(pooling38, mx37);
160
161
    migraphx::op::flatten flatten39;
    flatten39.axis = 1;
162
    auto mx39      = mm->add_instruction(flatten39, mx38);
163
    migraphx::op::identity identity40;
164
    auto mx40 = mm->add_instruction(identity40, mx39);
165
166
    migraphx::op::transpose transpose41;
    transpose41.dims = {1, 0};
167
    auto mx41        = mm->add_instruction(transpose41, mx5);
168
169
    migraphx::op::multibroadcast multibroadcast42;
    multibroadcast42.output_lens = {batch, 4096};
170
    auto mx42                    = mm->add_instruction(multibroadcast42, mx4);
171
172
173
174
    float dot43_alpha            = 1;
    float dot43_beta             = 1;
    auto mx43                    = migraphx::add_apply_alpha_beta(
        *mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta);
175
    migraphx::op::relu relu44;
176
    auto mx44 = mm->add_instruction(relu44, mx43);
177
    migraphx::op::identity identity45;
178
    auto mx45 = mm->add_instruction(identity45, mx44);
179
180
    migraphx::op::transpose transpose46;
    transpose46.dims = {1, 0};
181
    auto mx46        = mm->add_instruction(transpose46, mx3);
182
183
    migraphx::op::multibroadcast multibroadcast47;
    multibroadcast47.output_lens = {batch, 4096};
184
    auto mx47                    = mm->add_instruction(multibroadcast47, mx2);
185
186
187
188
    float dot48_alpha            = 1;
    float dot48_beta             = 1;
    auto mx48                    = migraphx::add_apply_alpha_beta(
        *mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta);
189
    migraphx::op::relu relu49;
190
    auto mx49 = mm->add_instruction(relu49, mx48);
191
192
    migraphx::op::transpose transpose50;
    transpose50.dims = {1, 0};
193
    auto mx50        = mm->add_instruction(transpose50, mx1);
194
195
    migraphx::op::multibroadcast multibroadcast51;
    multibroadcast51.output_lens = {batch, 1000};
196
    auto mx51                    = mm->add_instruction(multibroadcast51, mx0);
197
198
199
200
    float dot52_alpha            = 1;
    float dot52_beta             = 1;
    migraphx::add_apply_alpha_beta(
        *mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
201
202
203
204
205
206
    return p;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx