"example/15_grouped_gemm/CMakeLists.txt" did not exist on "9a8ee8a39a0aa6059c55faba05f6abb904fff6dd"
tf_test.cpp 27.8 KB
Newer Older
1
2
#include <iostream>
#include <vector>
Shucai Xiao's avatar
Shucai Xiao committed
3
#include <unordered_map>
4
#include <migraphx/literal.hpp>
Paul's avatar
Paul committed
5
6
7
8
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
9
10
11
12
13
14
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp>
#include "test.hpp"

Shucai Xiao's avatar
Shucai Xiao committed
15
16
17
18
// Parses the TensorFlow protobuf file `name` into a MIGraphX program.
// `is_nhwc` is forwarded as the layout flag, the second tf_options field is set
// to 1 (presumably the batch size — confirm against tf_options), and
// `dim_params` optionally overrides the dimensions of named inputs.
migraphx::program
parse_tf(const std::string& name,
         bool is_nhwc,
         const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {})
{
    const migraphx::tf_options options{is_nhwc, 1, dim_params};
    return migraphx::parse_tf(name, options);
}

Paul's avatar
Paul committed
23
24
// Parses `name` with no input-dimension overrides. NCHW graphs are returned as
// parsed. For NHWC graphs the main module is additionally run through
// simplify_reshapes, dead_code_elimination and eliminate_identity before returning.
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
    auto prog = parse_tf(name, is_nhwc);
    if(not is_nhwc)
        return prog;

    auto* main_mod = prog.get_main_module();
    migraphx::run_passes(*main_mod,
                         {migraphx::simplify_reshapes{},
                          migraphx::dead_code_elimination{},
                          migraphx::eliminate_identity{}});
    return prog;
}

35
36
37
TEST_CASE(add_test)
{
    // Expected: one elementwise add of two {1, 2, 2, 3} float parameters.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 2, 3}};

    auto lhs = main_mod->add_parameter("0", in_shape);
    auto rhs = main_mod->add_parameter("1", in_shape);
    main_mod->add_instruction(migraphx::op::add{}, lhs, rhs);

    auto prog = optimize_tf("add_test.pb", false);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
48
49
50
51
52
53
54
55
56
57
58
TEST_CASE(addv2_test)
{
    // Expected: one elementwise add of two {1, 2, 2, 3} float parameters.
    // Built on the main module, like every other expected program in this file.
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    mm->add_instruction(migraphx::op::add{}, l0, l1);
    auto prog = optimize_tf("addv2_test.pb", false);

    EXPECT(p == prog);
}

59
60
TEST_CASE(add_bcast_test)
{
    // Expected: the {2, 3} and {2, 1} inputs are both multibroadcast to {2, 3}
    // and then added.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape full_shape{migraphx::shape::float_type, {2, 3}};
    const migraphx::shape col_shape{migraphx::shape::float_type, {2, 1}};
    const auto out_lens = full_shape.lens();

    auto a       = main_mod->add_parameter("0", full_shape);
    auto b       = main_mod->add_parameter("1", col_shape);
    auto a_bcast = main_mod->add_instruction(migraphx::op::multibroadcast{out_lens}, a);
    auto b_bcast = main_mod->add_instruction(migraphx::op::multibroadcast{out_lens}, b);
    main_mod->add_instruction(migraphx::op::add{}, a_bcast, b_bcast);

    auto prog = optimize_tf("add_bcast_test.pb", false);
    EXPECT(p == prog);
}

76
77
78
TEST_CASE(argmax_test)
{
    // Expected: an int32 scalar literal 2, then argmax over axis 2 followed by a
    // squeeze of axis 2. Input "0" is pinned to {4, 5, 6, 7} at parse time.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {4, 5, 6, 7}};

    auto input = main_mod->add_parameter("0", in_shape);
    main_mod->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
    auto reduced = main_mod->add_instruction(migraphx::op::argmax{2}, input);
    main_mod->add_instruction(migraphx::op::squeeze{{2}}, reduced);

    auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}});
    EXPECT(p == prog);
}

TEST_CASE(argmin_test)
{
    // Expected: an int32 scalar literal 2, then argmin over axis 2 followed by a
    // squeeze of axis 2.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {3, 4, 5, 6}};

    auto input = main_mod->add_parameter("0", in_shape);
    main_mod->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
    auto reduced = main_mod->add_instruction(migraphx::op::argmin{2}, input);
    main_mod->add_instruction(migraphx::op::squeeze{{2}}, reduced);

    auto prog = parse_tf("argmin_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
104
105
106
TEST_CASE(assert_less_equal_test)
{
    // Expected: an int32 literal {0, 1} and an add of the two inputs. Two identity
    // instructions follow: one takes both inputs, and the other takes that
    // identity plus the literal.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {2, 3}};
    const migraphx::literal int_values{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}};

    auto x       = main_mod->add_parameter("0", in_shape);
    auto y       = main_mod->add_parameter("1", in_shape);
    auto int_lit = main_mod->add_literal(int_values);
    main_mod->add_instruction(migraphx::op::add{}, x, y);
    auto first_identity = main_mod->add_instruction(migraphx::op::identity{}, x, y);
    main_mod->add_instruction(migraphx::op::identity{}, first_identity, int_lit);

    auto prog = optimize_tf("assert_less_equal_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
122
123
124
125
TEST_CASE(batchmatmul_test)
{
    // Expected: both operands have their last two axes swapped, then are fed to dot.
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto a = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}});
    auto b = main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});

    const migraphx::op::transpose swap_last_two{{0, 1, 3, 2}};
    auto a_t = main_mod->add_instruction(swap_last_two, a);
    auto b_t = main_mod->add_instruction(swap_last_two, b);
    main_mod->add_instruction(migraphx::op::dot{}, a_t, b_t);

    auto prog = optimize_tf("batchmatmul_test.pb", false);
    EXPECT(p == prog);
}

139
140
TEST_CASE(batchnorm_test)
{
    // Expected: a spatial batch_norm_inference (epsilon 1.001e-5, momentum 0.9).
    // Its inputs are parameter "0", a literal of 32 ones, and parameters "2", "3"
    // and "4". The parameters are added before the literal.
    const float epsilon  = 1.001e-5f;
    const float momentum = 0.9f;

    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::op::batch_norm_inference bn_op{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
    const migraphx::shape vec_shape{migraphx::shape::float_type, {32}};
    const std::vector<float> ones(32, 1.0f);

    auto x      = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
    auto in2    = main_mod->add_parameter("2", vec_shape);
    auto in3    = main_mod->add_parameter("3", vec_shape);
    auto in4    = main_mod->add_parameter("4", vec_shape);
    auto ones_l = main_mod->add_literal(migraphx::literal{vec_shape, ones});
    main_mod->add_instruction(bn_op, x, ones_l, in2, in3, in4);

    auto prog = optimize_tf("batchnorm_test.pb", true);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
TEST_CASE(batchnormv3_test)
{
    // Expected: a spatial batch_norm_inference (epsilon 1.0e-5, momentum 0.9).
    // Its inputs are parameter "0", a literal of 32 ones, and parameters "2", "3"
    // and "4". Built on the main module, like the other expected programs here.
    float epsilon  = 1.0e-5f;
    float momentum = 0.9f;

    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::op::batch_norm_inference op{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
    migraphx::shape s0{migraphx::shape::float_type, {32}};
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
    std::vector<float> const_vals(32);
    std::fill(const_vals.begin(), const_vals.end(), 1.0f);

    auto l2 = mm->add_parameter("2", s0);
    auto l3 = mm->add_parameter("3", s0);
    auto l4 = mm->add_parameter("4", s0);
    auto l1 = mm->add_literal(migraphx::literal{s0, const_vals});
    mm->add_instruction(op, l0, l1, l2, l3, l4);
    auto prog = optimize_tf("batchnormv3_test.pb", true);

    EXPECT(p == prog);
}

187
188
189
TEST_CASE(biasadd_test)
{
    // Expected: the {500} bias parameter is broadcast along axis 1 to the input's
    // lens and added to the input.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 500, 1, 1}};
    const uint64_t bcast_axis = 1;

    auto input = main_mod->add_parameter("0", in_shape);
    auto bias  = main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
    auto bias_bcast = main_mod->add_instruction(
        migraphx::op::broadcast{bcast_axis, input->get_shape().lens()}, bias);
    main_mod->add_instruction(migraphx::op::add{}, input, bias_bcast);

    auto prog = optimize_tf("biasadd_test.pb", true);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
TEST_CASE(biasadd_scalar_test)
{
    // Expected: a literal of shape {1} with stride {0} holding 1.0 is broadcast
    // along axis 1 to the {1, 1} input's lens and added to the input.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 1}};
    const migraphx::shape bias_shape{migraphx::shape::float_type, {1}, {0}};
    const uint64_t bcast_axis = 1;

    auto input      = main_mod->add_parameter("0", in_shape);
    auto bias       = main_mod->add_literal(migraphx::literal{bias_shape, {1.0}});
    auto bias_bcast = main_mod->add_instruction(
        migraphx::op::broadcast{bcast_axis, input->get_shape().lens()}, bias);
    main_mod->add_instruction(migraphx::op::add{}, input, bias_bcast);

    auto prog = optimize_tf("biasadd_scalar_test.pb", true);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
220
221
222
TEST_CASE(cast_test)
{
    // Expected: the float input converted to int32.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 3, 16, 16}};

    auto input = main_mod->add_parameter("0", in_shape);
    main_mod->add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, input);

    auto prog = optimize_tf("cast_test.pb", false);
    EXPECT(p == prog);
}

232
233
234
TEST_CASE(concat_test)
{
    // Expected: an int32 literal holding the axis, then a concat of both inputs
    // along axis 1.
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto first  = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}});
    auto second = main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});

    const int axis = 1;
    // tf passes the axis as the third input, in int32 format. The literal is built
    // from a std::vector so its stride is 1, as in the tf parser.
    main_mod->add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
    main_mod->add_instruction(migraphx::op::concat{axis}, first, second);

    auto prog = optimize_tf("concat_test.pb", false);
    EXPECT(p == prog);
}

TEST_CASE(const_test)
{
    // Expected: a single float scalar literal 1.0.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const std::vector<float> value{1.0f};
    main_mod->add_literal(migraphx::shape{migraphx::shape::float_type}, value);

    auto prog = optimize_tf("constant_test.pb", false);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
263
// Builds the expected program shared by conv_test and conv_nchw_test. A
// {3, 3, 3, 32} weight literal of ones is transposed with {3, 2, 0, 1}. It is
// then used by a "same"-padded convolution (padding {1, 1}, stride {1, 1},
// dilation {1, 1}) over the {1, 3, 16, 16} input.
migraphx::program create_conv()
{
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto input = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    const std::vector<float> weights(3 * 3 * 3 * 32, 1.0f);
    auto weight_lit = main_mod->add_literal(
        migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weights);

    migraphx::op::convolution conv;
    conv.padding_mode = migraphx::op::padding_mode_t::same;
    conv.padding      = {1, 1};
    conv.stride       = {1, 1};
    conv.dilation     = {1, 1};

    auto weight_t = main_mod->add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, weight_lit);
    main_mod->add_instruction(conv, input, weight_t);
    return p;
}

TEST_CASE(conv_test)
{
    // NHWC variant: parsed and optimized, then compared with create_conv().
    auto p    = create_conv();
    auto prog = optimize_tf("conv_test.pb", true);
    EXPECT(p == prog);
}

TEST_CASE(conv_nchw_test)
{
    // NCHW variant: must parse to the same program as create_conv().
    auto p    = create_conv();
    auto prog = optimize_tf("conv_nchw_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
301
302
303
304
TEST_CASE(depthwiseconv_test)
{
    // Expected: the {3, 3, 3, 1} weight literal of ones is transposed with
    // {3, 2, 0, 1}, made contiguous and reshaped to {3, 1, 3, 3}. It then feeds a
    // "same"-padded convolution with group = 3.
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto input = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    const std::vector<float> weights(3 * 3 * 3 * 1, 1.0f);
    auto weight_lit = main_mod->add_literal(
        migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weights);

    migraphx::op::convolution conv;
    conv.padding_mode = migraphx::op::padding_mode_t::same;
    conv.padding      = {1, 1};
    conv.stride       = {1, 1};
    conv.dilation     = {1, 1};
    conv.group        = 3;

    auto transposed = main_mod->add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, weight_lit);
    auto packed     = main_mod->add_instruction(migraphx::op::contiguous{}, transposed);
    auto reshaped   = main_mod->add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, packed);
    main_mod->add_instruction(conv, input, reshaped);

    auto prog = optimize_tf("depthwise_conv_test.pb", true);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
328
329
330
TEST_CASE(expanddims_test)
{
    // Expected: an int literal 0, then the {2, 3, 4} input reshaped to {1, 2, 3, 4}.
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    main_mod->add_literal(0);
    main_mod->add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, input);

    auto prog = optimize_tf("expanddims_test.pb", false);
    EXPECT(p == prog);
}

TEST_CASE(expanddims_test_neg_dims)
{
    // Checks that a negative dim value in the pb is parsed correctly. Expected: an
    // int literal -1, then the {2, 3, 4} input reshaped to {2, 3, 4, 1}.
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    main_mod->add_literal(-1);
    main_mod->add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, input);

    auto prog = optimize_tf("expanddims_neg_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
357
358
359
360
TEST_CASE(gather_test)
{
    // Expected: an int32 index literal {1, 1} and an int literal 1, then a gather
    // from the {2, 4} input along axis 1 using those indices.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape idx_shape{migraphx::shape::int32_type, {2}};

    auto data    = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    auto indices = main_mod->add_literal(migraphx::literal{idx_shape, {1, 1}});
    main_mod->add_literal(1);

    const int axis = 1;
    main_mod->add_instruction(migraphx::op::gather{axis}, data, indices);

    auto prog = optimize_tf("gather_test.pb", false);
    EXPECT(p == prog);
}

375
376
377
TEST_CASE(identity_test)
{
    // Expected: a single identity applied to the input.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 3, 16, 16}};

    auto input = main_mod->add_parameter("0", in_shape);
    main_mod->add_instruction(migraphx::op::identity{}, input);

    auto prog = optimize_tf("identity_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
387
388
389
TEST_CASE(matmul_test)
{
    // Expected: both 2-D operands transposed with {1, 0}, then fed to dot.
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto a = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
    auto b = main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});

    const migraphx::op::transpose flip{{1, 0}};
    auto a_t = main_mod->add_instruction(flip, a);
    auto b_t = main_mod->add_instruction(flip, b);
    main_mod->add_instruction(migraphx::op::dot{}, a_t, b_t);

    auto prog = optimize_tf("matmul_test.pb", false);
    EXPECT(p == prog);
}

404
405
406
TEST_CASE(mean_test)
{
    // Expected: the int32 axes literal {2, 3} added twice, and two reduce_mean
    // over axes {2, 3}. The first reduce_mean's result is unused; the second is
    // followed by a squeeze of axes {2, 3}.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::literal axes_lit{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}};
    const migraphx::op::reduce_mean mean_op{{2, 3}};

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_literal(axes_lit);
    main_mod->add_literal(axes_lit);
    main_mod->add_instruction(mean_op, input);
    auto reduced = main_mod->add_instruction(mean_op, input);
    main_mod->add_instruction(migraphx::op::squeeze{{2, 3}}, reduced);

    auto prog = optimize_tf("mean_test.pb", false);
    EXPECT(p == prog);
}

TEST_CASE(mean_test_nhwc)
{
    // Expected: the input is transposed with {0, 2, 3, 1}, reduced with
    // reduce_mean over axes {1, 2}, and those axes are squeezed away.
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto l1  = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
    migraphx::op::reduce_mean op{{1, 2}};
    auto l2 = mm->add_instruction(op, l1);
    mm->add_instruction(migraphx::op::squeeze{{1, 2}}, l2);
    auto prog = optimize_tf("mean_test_nhwc.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
438
439
440
441
TEST_CASE(mul_test)
{
    // Expected: one elementwise mul of two {1, 1, 1, 16} float parameters.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 1, 1, 16}};

    auto lhs = main_mod->add_parameter("0", in_shape);
    auto rhs = main_mod->add_parameter("1", in_shape);
    main_mod->add_instruction(migraphx::op::mul{}, lhs, rhs);

    auto prog = optimize_tf("mul_test.pb", false);
    EXPECT(p == prog);
}

452
453
454
TEST_CASE(onehot_test)
{
    // Expected: an int32 index literal {1, 1, 1, 1, 1}, scalar literals 2, 1.0f
    // and 0.0f, and a {2, 2} float literal {1, 0, 0, 1}. The result is a gather
    // of the {2, 2} literal's rows (axis 0) by the index literal.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape idx_shape{migraphx::shape::int32_type, {5}};
    const migraphx::shape table_shape{migraphx::shape::float_type, {2, 2}};

    auto indices = main_mod->add_literal(migraphx::literal{idx_shape, {1, 1, 1, 1, 1}});
    main_mod->add_literal(2);
    main_mod->add_literal(1.0f);
    main_mod->add_literal(0.0f);
    auto table = main_mod->add_literal(migraphx::literal{table_shape, {1, 0, 0, 1}});

    const int axis = 0;
    main_mod->add_instruction(migraphx::op::gather{axis}, table, indices);

    auto prog = optimize_tf("onehot_test.pb", false);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
471
472
473
474
475
476
477
478
TEST_CASE(noop_test)
{
    // Expected: an empty program.
    migraphx::program p;

    auto prog = optimize_tf("noop_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
479
480
481
TEST_CASE(pack_test)
{
    // Expected: each {2} input is unsqueezed at axis 1, in order, and the results
    // are concatenated along axis 1.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {2}};

    auto a = main_mod->add_parameter("0", in_shape);
    auto b = main_mod->add_parameter("1", in_shape);
    auto c = main_mod->add_parameter("2", in_shape);

    const int64_t axis = 1;
    std::vector<migraphx::instruction_ref> expanded;
    for(auto arg : {a, b, c})
        expanded.push_back(main_mod->add_instruction(migraphx::op::unsqueeze{{axis}}, arg));
    main_mod->add_instruction(migraphx::op::concat{static_cast<int>(axis)}, expanded);

    auto prog = optimize_tf("pack_test.pb", false);
    EXPECT(p == prog);
}

503
504
505
TEST_CASE(pack_test_nhwc)
{
    // Expected: each {1, 2, 1, 1} input is added and immediately transposed with
    // {0, 2, 3, 1}. Each transposed value is unsqueezed at axis 3, and the results
    // are concatenated along axis 3.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 1, 1}};
    const migraphx::op::transpose perm{{0, 2, 3, 1}};

    std::vector<migraphx::instruction_ref> permuted;
    for(const char* name : {"0", "1", "2"})
    {
        auto param = main_mod->add_parameter(name, in_shape);
        permuted.push_back(main_mod->add_instruction(perm, param));
    }

    const int64_t nchw_axis = 3;
    std::vector<migraphx::instruction_ref> expanded;
    for(auto arg : permuted)
        expanded.push_back(main_mod->add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg));
    main_mod->add_instruction(migraphx::op::concat{static_cast<int>(nchw_axis)}, expanded);

    auto prog = optimize_tf("pack_test_nhwc.pb", true);
    EXPECT(p == prog);
}

530
531
532
TEST_CASE(pooling_test)
{
    // Expected: a single max pool (lengths {2, 2}, stride {2, 2}) over the input.
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    migraphx::op::pooling max_pool_op{"max"};
    max_pool_op.stride  = {2, 2};
    max_pool_op.lengths = {2, 2};
    mm->add_instruction(max_pool_op, l0);
    auto prog = optimize_tf("pooling_test.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
548
549
550
TEST_CASE(pow_test)
{
    // Expected: one elementwise pow of two {1, 2, 2, 3} float parameters.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 2, 3}};

    auto base     = main_mod->add_parameter("0", in_shape);
    auto exponent = main_mod->add_parameter("1", in_shape);
    main_mod->add_instruction(migraphx::op::pow{}, base, exponent);

    auto prog = optimize_tf("pow_test.pb", false);
    EXPECT(p == prog);
}

561
562
563
TEST_CASE(relu_test)
{
    // Expected: a single relu applied to the input.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 3, 16, 16}};

    auto input = main_mod->add_parameter("0", in_shape);
    main_mod->add_instruction(migraphx::op::relu{}, input);

    auto prog = optimize_tf("relu_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
573
574
575
TEST_CASE(relu6_test)
{
    // Expected: scalar literals 0.0f and 6.0f are multibroadcast to the input lens
    // and used as the lower and upper bounds of a clip.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const std::vector<size_t> input_lens{1, 3, 16, 16};

    auto input   = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
    auto lo      = main_mod->add_literal(0.0f);
    auto hi      = main_mod->add_literal(6.0f);
    auto lo_full = main_mod->add_instruction(migraphx::op::multibroadcast{input_lens}, lo);
    auto hi_full = main_mod->add_instruction(migraphx::op::multibroadcast{input_lens}, hi);
    main_mod->add_instruction(migraphx::op::clip{}, input, lo_full, hi_full);

    auto prog = optimize_tf("relu6_test.pb", false);
    EXPECT(p == prog);
}

590
591
592
TEST_CASE(reshape_test)
{
    // Expected: an int32 literal with the new dims {1, 1, 1, 16}, then the {16}
    // input reshaped to those dims.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape dims_shape{migraphx::shape::int32_type, {4}};

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});
    // In tf the second arg is a literal containing the new dimensions.
    main_mod->add_literal(migraphx::literal{dims_shape, {1, 1, 1, 16}});
    main_mod->add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, input);

    auto prog = optimize_tf("reshape_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
605
606
607
TEST_CASE(rsqrt_test)
{
    // Expected: a single rsqrt applied to the input.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 3, 16, 16}};

    auto input = main_mod->add_parameter("0", in_shape);
    main_mod->add_instruction(migraphx::op::rsqrt{}, input);

    auto prog = optimize_tf("rsqrt_test.pb", false);
    EXPECT(p == prog);
}

617
618
619
TEST_CASE(shape_test)
{
    // Expected: the input parameter plus an int32 literal holding its dims.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape dims_shape{migraphx::shape::int32_type, {4}};

    main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_literal(migraphx::literal{dims_shape, {1, 3, 16, 16}});

    auto prog = optimize_tf("shape_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
630
631
632
TEST_CASE(slice_test)
{
    // Expected: int32 literals {1, 0} and {2, -1}, then a slice of the {5, 10}
    // input over axes {0, 1} with starts {1, 0} and ends {3, 10}.
    migraphx::program p;
    auto* main_mod       = p.get_main_module();
    std::size_t num_axes = 2;

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}});
    const migraphx::shape arg_shape{migraphx::shape::int32_type, {num_axes}};
    main_mod->add_literal(migraphx::literal{arg_shape, {1, 0}});
    main_mod->add_literal(migraphx::literal{arg_shape, {2, -1}});

    std::vector<int64_t> axes(num_axes);
    std::iota(axes.begin(), axes.end(), 0);
    main_mod->add_instruction(migraphx::op::slice{axes, {1, 0}, {3, 10}}, input);

    auto prog = optimize_tf("slice_test.pb", false);
    EXPECT(p == prog);
}

652
653
654
TEST_CASE(softmax_test)
{
    // Expected: softmax over axis 1 of the {1, 3} input.
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
    main_mod->add_instruction(migraphx::op::softmax{1}, input);

    auto prog = optimize_tf("softmax_test.pb", false);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
664
665
666
TEST_CASE(split_test)
{
    // Expected: the {5, 30} input is sliced into three {5, 10} column blocks.
    // Adjacent blocks are concatenated pairwise along axis 1.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const std::vector<int64_t> axes{0, 1};

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    main_mod->add_literal(3); // num_splits
    main_mod->add_literal(1); // split axis
    main_mod->add_literal(1); // concat axis
    main_mod->add_literal(1); // concat axis

    std::vector<migraphx::instruction_ref> pieces;
    for(int64_t start = 0; start < 30; start += 10)
        pieces.push_back(main_mod->add_instruction(
            migraphx::op::slice{axes, {0, start}, {5, start + 10}}, input));
    main_mod->add_instruction(migraphx::op::concat{1}, pieces[0], pieces[1]);
    main_mod->add_instruction(migraphx::op::concat{1}, pieces[1], pieces[2]);

    auto prog = parse_tf("split_test.pb", false);
    EXPECT(p == prog);
}

TEST_CASE(split_test_one_output)
{
    // Expected: a single split reduces to an identity of the input.
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    main_mod->add_literal(1); // num_splits
    main_mod->add_literal(1); // split axis
    main_mod->add_instruction(migraphx::op::identity{}, input);

    auto prog = parse_tf("split_test_one_output.pb", false);
    EXPECT(p == prog);
}

TEST_CASE(split_test_vector_as_input)
{
    // Expected: the {5, 30} input is sliced along axis 1 into widths 4, 15 and 11,
    // and adjacent slices are concatenated pairwise along axis 1.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const std::vector<int64_t> axes{0, 1};
    const std::vector<int64_t> widths{4, 15, 11};

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    // split sizes
    main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {3}}, {4, 15, 11}});
    main_mod->add_literal(1); // split axis
    main_mod->add_literal(1); // concat axis
    main_mod->add_literal(1); // concat axis

    std::vector<migraphx::instruction_ref> pieces;
    int64_t offset = 0;
    for(auto width : widths)
    {
        pieces.push_back(main_mod->add_instruction(
            migraphx::op::slice{axes, {0, offset}, {5, offset + width}}, input));
        offset += width;
    }
    main_mod->add_instruction(migraphx::op::concat{1}, pieces[0], pieces[1]);
    main_mod->add_instruction(migraphx::op::concat{1}, pieces[1], pieces[2]);

    auto prog = parse_tf("split_test_vector_as_input.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
725
726
727
TEST_CASE(sqdiff_test)
{
    // Expected: one sqdiff of two {1, 2, 2, 3} float parameters.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 2, 3}};

    auto lhs = main_mod->add_parameter("0", in_shape);
    auto rhs = main_mod->add_parameter("1", in_shape);
    main_mod->add_instruction(migraphx::op::sqdiff{}, lhs, rhs);

    auto prog = optimize_tf("sqdiff_test.pb", false);
    EXPECT(p == prog);
}

738
739
740
TEST_CASE(squeeze_test)
{
    // Expected: axes 0 and 3 of the {1, 2, 3, 1} input squeezed away.
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
    main_mod->add_instruction(migraphx::op::squeeze{{0, 3}}, input);

    auto prog = optimize_tf("squeeze_test.pb", false);
    EXPECT(p == prog);
}
Khalique's avatar
Khalique committed
749

Khalique's avatar
Khalique committed
750
751
752
TEST_CASE(stopgradient_test)
{
    // Expected: StopGradient parses to an identity of the input.
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 3, 16, 16}};

    auto input = main_mod->add_parameter("0", in_shape);
    main_mod->add_instruction(migraphx::op::identity{}, input);

    auto prog = optimize_tf("stopgradient_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
762
763
764
// Parses stridedslice_test.pb as NHWC and runs the passes optimize_tf applies
// in that mode. Expected graph: the NCHW input {1, 10, 1, 1} is transposed to
// NHWC ({1, 1, 1, 10}), sliced on all four axes to ends {1, 1, 1, 5}, and then
// axis 1 (extent 1 after slicing) is squeezed away.
TEST_CASE(stridedslice_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
    auto nhwc = main_mod->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, input);

    // Slice over every axis: 0, 1, 2, 3.
    std::vector<int64_t> slice_axes(4);
    std::iota(slice_axes.begin(), slice_axes.end(), 0);
    auto sliced = main_mod->add_instruction(
        migraphx::op::slice{slice_axes, {0, 0, 0, 0}, {1, 1, 1, 5}}, nhwc);

    main_mod->add_instruction(migraphx::op::squeeze{{1}}, sliced);

    auto parsed = optimize_tf("stridedslice_test.pb", true);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
783
784
785
// Parses stridedslice_masks_test.pb as NHWC without running any passes, so the
// expected graph keeps the three int32 constants from the TF graph alongside
// the NCHW -> NHWC transpose, the slice, and the NHWC -> NCHW transpose back.
TEST_CASE(stridedslice_masks_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});

    // Literals for starts, ends, and strides in tf (NHWC format), in that order.
    const migraphx::shape vec4_i32{migraphx::shape::int32_type, {4}};
    main_mod->add_literal(vec4_i32, std::vector<int>{0, 1, 1, 0});
    main_mod->add_literal(vec4_i32, std::vector<int>{0, 0, 0, 0});
    main_mod->add_literal(vec4_i32, std::vector<int>{1, 1, 1, 1});

    // Slice over every axis of the NHWC tensor.
    std::vector<int64_t> slice_axes(4);
    std::iota(slice_axes.begin(), slice_axes.end(), 0);
    const migraphx::op::slice slice_op{slice_axes, {0, 1, 1, 0}, {1, 3, 3, 10}};

    auto nhwc   = main_mod->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, input);
    auto sliced = main_mod->add_instruction(slice_op, nhwc);
    main_mod->add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, sliced);

    auto parsed = parse_tf("stridedslice_masks_test.pb", true);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
811
812
813
// Parses sub_test.pb and expects a single sub instruction computing
// parameter "0" minus parameter "1", both float {1, 2, 2, 3}.
TEST_CASE(sub_test)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 2, 3}};

    migraphx::program expected;
    auto* main_mod  = expected.get_main_module();
    auto minuend    = main_mod->add_parameter("0", input_shape);
    auto subtrahend = main_mod->add_parameter("1", input_shape);
    main_mod->add_instruction(migraphx::op::sub{}, minuend, subtrahend);

    auto parsed = parse_tf("sub_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
824
825
826
// Checks that a TF Tanh node is imported as a single migraphx tanh
// instruction applied to the graph input.
// NOTE(review): the expected graph assumes tanh_test.pb contains exactly one
// float input named "0" of shape {1, 2, 2, 3} feeding a Tanh node -- TODO
// confirm against the script that generates the .pb fixtures.
TEST_CASE(tanh_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    mm->add_instruction(migraphx::op::tanh{}, l0);
    auto prog = parse_tf("tanh_test.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
837
838
839
// Parses transpose_test.pb and expects the float {1, 3, 16, 16} input "0" to be
// transposed with permutation {0, 2, 3, 1}. The permutation also appears as
// an int32 {4} literal, which remains because optimize_tf runs no passes when
// is_nhwc is false.
TEST_CASE(transpose_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});

    const migraphx::shape perm_shape{migraphx::shape::int32_type, {4}};
    main_mod->add_literal(migraphx::literal{perm_shape, {0, 2, 3, 1}});
    main_mod->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, input);

    auto parsed = optimize_tf("transpose_test.pb", false);
    EXPECT(expected == parsed);
}

851
852
853
// Parses variable_batch_test.pb with default options (no dim_params) and
// expects the input "0" to be fixed at float {1, 3, 16, 16} and passed through
// a single identity instruction.
TEST_CASE(variable_batch_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::op::identity{}, input);

    auto parsed = optimize_tf("variable_batch_test.pb", false);
    EXPECT(expected == parsed);
}

863
// Test entry point: hands the command line to the driver from test.hpp, which
// runs the registered TEST_CASEs.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}