alexnet.cpp 11.9 KB
Newer Older
1

2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
25
#include <string>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/json.hpp>
#include "models.hpp"
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
/// Builds the AlexNet model as a migraphx::program for the driver.
///
/// The main module takes a single float parameter "data_0" of shape
/// {batch, 3, 224, 224}. All weights and biases are literals produced by
/// migraphx::generate_literal with seeds 0..18 (some wrapped in migraphx::abs).
/// The graph is five convolution + bias + relu stages (the first two followed
/// by lrn and pooling, the last by pooling), a reshape to {batch, 9216}, three
/// dot + bias layers (the first two followed by relu) and a softmax over
/// axis 1, whose result is the module's return value.
///
/// @param batch Leading (batch) dimension of the input. It is also used as the
///              leading dimension of every broadcast/multibroadcast/reshape in
///              the graph so that all intermediate shapes follow the input.
/// @return The constructed program.
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{
    // The fixed-shape ops below used to hard-code a leading dimension of 1,
    // which only agreed with the {batch, 3, 224, 224} input when batch == 1.
    // The leading dimension is now derived from `batch`; for batch == 1 the
    // generated op strings are identical to the previous hard-coded ones.
    const std::string n = std::to_string(batch);
    // Output lengths shared by the bias terms of the dot layers.
    const std::string fc_hidden_lens = "{out_lens:[" + n + ",4096]}";
    const std::string fc_output_lens = "{out_lens:[" + n + ",1000]}";

    migraphx::program p;
    migraphx::module_ref mmain = p.get_main_module();

    // Single-element literals that are multiplied into the dot-layer biases.
    auto x_main_module_0       = mmain->add_literal(migraphx::abs(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 0)));
    auto x_main_module_1       = mmain->add_literal(migraphx::abs(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1)));
    auto x_main_module_2       = mmain->add_literal(migraphx::abs(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2)));

    // Network input.
    auto x_data_0              = mmain->add_parameter(
        "data_0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});

    // Weights and biases of the three dot layers (declared last layer first).
    auto x_main_module_4 = mmain->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 3));
    auto x_main_module_5 = mmain->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 4));
    auto x_main_module_6 = mmain->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 5));
    auto x_main_module_7 = mmain->add_literal(migraphx::abs(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 6)));
    auto x_main_module_8 = mmain->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 7));
    auto x_main_module_9 = mmain->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 8));

    // Weights and biases of the five convolutions (declared last layer first).
    auto x_main_module_10 = mmain->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {256, 192, 3, 3}}, 9));
    auto x_main_module_11 = mmain->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 10));
    auto x_main_module_12 = mmain->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
    auto x_main_module_13 = mmain->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 12));
    auto x_main_module_14 = mmain->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {384, 256, 3, 3}}, 13));
    auto x_main_module_15 = mmain->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 14));
    auto x_main_module_16 = mmain->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {256, 48, 5, 5}}, 15));
    auto x_main_module_17 = mmain->add_literal(migraphx::abs(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 16)));
    auto x_main_module_18 = mmain->add_literal(migraphx::generate_literal(
        migraphx::shape{migraphx::shape::float_type, {96, 3, 11, 11}}, 17));
    auto x_main_module_19 = mmain->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 18));

    // Stage 1: 11x11 convolution (stride 4) + bias + relu + lrn + pooling.
    auto x_main_module_20 = mmain->add_instruction(
        migraphx::make_json_op("convolution",
                               "{dilation:[1,1],group:1,padding:[0,0,0,0],padding_mode:0,stride:[4,"
                               "4],use_dynamic_same_auto_pad:0}"),
        x_data_0,
        x_main_module_18);
    auto x_main_module_21 = mmain->add_instruction(
        migraphx::make_json_op("broadcast", "{axis:1,out_lens:[" + n + ",96,54,54]}"),
        x_main_module_19);
    auto x_main_module_22 =
        mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21);
    auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22);
    auto x_main_module_24 = mmain->add_instruction(
        migraphx::make_json_op("lrn", "{alpha:9.999999747378752e-05,beta:0.75,bias:1.0,size:5}"),
        x_main_module_23);
    auto x_main_module_25 = mmain->add_instruction(
        migraphx::make_json_op(
            "pooling",
            "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
        x_main_module_24);

    // Stage 2: 5x5 grouped convolution + bias + relu + lrn + pooling.
    auto x_main_module_26 = mmain->add_instruction(
        migraphx::make_json_op("convolution",
                               "{dilation:[1,1],group:2,padding:[2,2,2,2],padding_mode:0,stride:[1,"
                               "1],use_dynamic_same_auto_pad:0}"),
        x_main_module_25,
        x_main_module_16);
    auto x_main_module_27 = mmain->add_instruction(
        migraphx::make_json_op("broadcast", "{axis:1,out_lens:[" + n + ",256,26,26]}"),
        x_main_module_17);
    auto x_main_module_28 =
        mmain->add_instruction(migraphx::make_op("add"), x_main_module_26, x_main_module_27);
    auto x_main_module_29 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_28);
    auto x_main_module_30 = mmain->add_instruction(
        migraphx::make_json_op("lrn", "{alpha:9.999999747378752e-05,beta:0.75,bias:1.0,size:5}"),
        x_main_module_29);
    auto x_main_module_31 = mmain->add_instruction(
        migraphx::make_json_op(
            "pooling",
            "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
        x_main_module_30);

    // Stage 3: 3x3 convolution + bias + relu.
    auto x_main_module_32 = mmain->add_instruction(
        migraphx::make_json_op("convolution",
                               "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
                               "1],use_dynamic_same_auto_pad:0}"),
        x_main_module_31,
        x_main_module_14);
    auto x_main_module_33 = mmain->add_instruction(
        migraphx::make_json_op("broadcast", "{axis:1,out_lens:[" + n + ",384,12,12]}"),
        x_main_module_15);
    auto x_main_module_34 =
        mmain->add_instruction(migraphx::make_op("add"), x_main_module_32, x_main_module_33);
    auto x_main_module_35 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_34);

    // Stage 4: 3x3 grouped convolution + bias + relu.
    auto x_main_module_36 = mmain->add_instruction(
        migraphx::make_json_op("convolution",
                               "{dilation:[1,1],group:2,padding:[1,1,1,1],padding_mode:0,stride:[1,"
                               "1],use_dynamic_same_auto_pad:0}"),
        x_main_module_35,
        x_main_module_12);
    auto x_main_module_37 = mmain->add_instruction(
        migraphx::make_json_op("broadcast", "{axis:1,out_lens:[" + n + ",384,12,12]}"),
        x_main_module_13);
    auto x_main_module_38 =
        mmain->add_instruction(migraphx::make_op("add"), x_main_module_36, x_main_module_37);
    auto x_main_module_39 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_38);

    // Stage 5: 3x3 grouped convolution + bias + relu + pooling.
    auto x_main_module_40 = mmain->add_instruction(
        migraphx::make_json_op("convolution",
                               "{dilation:[1,1],group:2,padding:[1,1,1,1],padding_mode:0,stride:[1,"
                               "1],use_dynamic_same_auto_pad:0}"),
        x_main_module_39,
        x_main_module_10);
    auto x_main_module_41 = mmain->add_instruction(
        migraphx::make_json_op("broadcast", "{axis:1,out_lens:[" + n + ",256,12,12]}"),
        x_main_module_11);
    auto x_main_module_42 =
        mmain->add_instruction(migraphx::make_op("add"), x_main_module_40, x_main_module_41);
    auto x_main_module_43 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_42);
    auto x_main_module_44 = mmain->add_instruction(
        migraphx::make_json_op(
            "pooling",
            "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,1,1],stride:[2,2]}"),
        x_main_module_43);

    // Flatten to one row of 9216 values per batch element, then the first dot layer:
    // dot with the transposed weights, plus (bias * single-element literal), then relu.
    auto x_main_module_45 = mmain->add_instruction(
        migraphx::make_json_op("reshape", "{dims:[" + n + ",9216]}"), x_main_module_44);
    auto x_main_module_46 = mmain->add_instruction(
        migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_8);
    auto x_main_module_47 =
        mmain->add_instruction(migraphx::make_op("dot"), x_main_module_45, x_main_module_46);
    auto x_main_module_48 = mmain->add_instruction(
        migraphx::make_json_op("multibroadcast", fc_hidden_lens), x_main_module_9);
    auto x_main_module_49 = mmain->add_instruction(
        migraphx::make_json_op("multibroadcast", fc_hidden_lens), x_main_module_2);
    auto x_main_module_50 =
        mmain->add_instruction(migraphx::make_op("mul"), x_main_module_48, x_main_module_49);
    auto x_main_module_51 =
        mmain->add_instruction(migraphx::make_op("add"), x_main_module_47, x_main_module_50);
    auto x_main_module_52 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_51);
    auto x_main_module_53 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_52);

    // Second dot layer, same structure as the first.
    auto x_main_module_54 = mmain->add_instruction(
        migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_6);
    auto x_main_module_55 =
        mmain->add_instruction(migraphx::make_op("dot"), x_main_module_53, x_main_module_54);
    auto x_main_module_56 = mmain->add_instruction(
        migraphx::make_json_op("multibroadcast", fc_hidden_lens), x_main_module_7);
    auto x_main_module_57 = mmain->add_instruction(
        migraphx::make_json_op("multibroadcast", fc_hidden_lens), x_main_module_1);
    auto x_main_module_58 =
        mmain->add_instruction(migraphx::make_op("mul"), x_main_module_56, x_main_module_57);
    auto x_main_module_59 =
        mmain->add_instruction(migraphx::make_op("add"), x_main_module_55, x_main_module_58);
    auto x_main_module_60 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_59);
    auto x_main_module_61 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_60);

    // Final dot layer (1000 outputs per batch element) followed by softmax over axis 1.
    auto x_main_module_62 = mmain->add_instruction(
        migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_4);
    auto x_main_module_63 =
        mmain->add_instruction(migraphx::make_op("dot"), x_main_module_61, x_main_module_62);
    auto x_main_module_64 = mmain->add_instruction(
        migraphx::make_json_op("multibroadcast", fc_output_lens), x_main_module_5);
    auto x_main_module_65 = mmain->add_instruction(
        migraphx::make_json_op("multibroadcast", fc_output_lens), x_main_module_0);
    auto x_main_module_66 =
        mmain->add_instruction(migraphx::make_op("mul"), x_main_module_64, x_main_module_65);
    auto x_main_module_67 =
        mmain->add_instruction(migraphx::make_op("add"), x_main_module_63, x_main_module_66);
    auto x_main_module_68 =
        mmain->add_instruction(migraphx::make_json_op("softmax", "{axis:1}"), x_main_module_67);
    mmain->add_return({x_main_module_68});

    return p;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx