tf_test.cpp 32.4 KB
Newer Older
1
2
#include <iostream>
#include <vector>
Shucai Xiao's avatar
Shucai Xiao committed
3
#include <unordered_map>
4
#include <migraphx/literal.hpp>
Paul's avatar
Paul committed
5
6
7
8
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
9
10
11
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp>
12
#include <migraphx/make_op.hpp>
turneram's avatar
turneram committed
13
14
15
16
17
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/slice.hpp>
18
19
20

#include <migraphx/serialize.hpp>

21
22
#include "test.hpp"

Shucai Xiao's avatar
Shucai Xiao committed
23
24
25
// Parses the TensorFlow protobuf file `name` into a MIGraphX program without
// running any optimization passes.
//   is_nhwc           - forwarded as the layout flag of tf_options
//   dim_params        - optional per-input dimension overrides (presumably keyed
//                       by parameter name, e.g. "0" — TODO confirm)
//   output_node_names - optional names of graph nodes to expose as outputs
// The second tf_options field is always passed as 1.
migraphx::program
parse_tf(const std::string& name,
         bool is_nhwc,
         const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {},
         const std::vector<std::string>& output_node_names                           = {})
{
    const migraphx::tf_options options{is_nhwc, 1, dim_params, output_node_names};
    return migraphx::parse_tf(name, options);
}

Paul's avatar
Paul committed
33
34
// Parses the TensorFlow protobuf file `name` and normalizes the main module so
// tests can compare it against a hand-built expected program:
//   - for NHWC graphs, runs simplify_reshapes, dead_code_elimination and
//     eliminate_identity over the main module;
//   - removes a trailing "@return" instruction, since the expected programs
//     compared against this helper's result are built without one.
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
    auto prog = migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1});
    auto* mm  = prog.get_main_module();
    if(is_nhwc)
        migraphx::run_passes(*mm,
                             {migraphx::simplify_reshapes{},
                              migraphx::dead_code_elimination{},
                              migraphx::eliminate_identity{}});

    // Remove the trailing return instruction, if any. The emptiness check must
    // come first: std::prev(end()) on an empty module is undefined behaviour,
    // and comparing its result against end() can never detect that case.
    if(mm->begin() != mm->end())
    {
        auto last_ins = std::prev(mm->end());
        if(last_ins->name() == "@return")
        {
            mm->remove_instruction(last_ins);
        }
    }
    return prog;
}

55
56
57
TEST_CASE(add_test)
{
    // Expected program: elementwise add of two float {1, 2, 2, 3} inputs.
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 2, 3}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto lhs       = main_mod->add_parameter("0", input_shape);
    auto rhs       = main_mod->add_parameter("1", input_shape);
    main_mod->add_instruction(migraphx::make_op("add"), lhs, rhs);

    auto parsed = optimize_tf("add_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
68
69
70
TEST_CASE(addv2_test)
{
    // Expected program: the AddV2 graph maps onto the same elementwise add of two
    // float {1, 2, 2, 3} inputs as add_test.
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 2, 3}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto lhs       = main_mod->add_parameter("0", input_shape);
    auto rhs       = main_mod->add_parameter("1", input_shape);
    main_mod->add_instruction(migraphx::make_op("add"), lhs, rhs);

    auto parsed = optimize_tf("addv2_test.pb", false);
    EXPECT(expected == parsed);
}

80
81
TEST_CASE(add_bcast_test)
{
    // Expected program: the {2, 1} input is multibroadcast to {2, 3} before the add.
    const migraphx::shape full_shape{migraphx::shape::float_type, {2, 3}};
    const migraphx::shape narrow_shape{migraphx::shape::float_type, {2, 1}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto full      = main_mod->add_parameter("0", full_shape);
    auto narrow    = main_mod->add_parameter("1", narrow_shape);
    auto widened   = main_mod->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", full_shape.lens()}}), narrow);
    main_mod->add_instruction(migraphx::make_op("add"), full, widened);

    auto parsed = optimize_tf("add_bcast_test.pb", false);
    EXPECT(expected == parsed);
}

97
98
99
TEST_CASE(argmax_test)
{
    // Expected program: argmax over axis 2 followed by a squeeze of that axis.
    // parse_tf (not optimize_tf) is used, so the @return is kept and the
    // expectation adds one too.
    const std::vector<std::size_t> input_dims{4, 5, 6, 7};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_dims});
    // The axis operand of the TF node remains as an int32 scalar literal.
    main_mod->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
    auto arg      = main_mod->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), input);
    auto squeezed = main_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), arg);
    main_mod->add_return({squeezed});

    auto parsed = parse_tf("argmax_test.pb", false, {{"0", input_dims}});
    EXPECT(expected == parsed);
}

TEST_CASE(argmin_test)
{
    // Expected program: argmin over axis 2 followed by a squeeze of that axis,
    // returned explicitly because parse_tf keeps the @return instruction.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
    // The axis operand of the TF node remains as an int32 scalar literal.
    main_mod->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
    auto arg      = main_mod->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), input);
    auto squeezed = main_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), arg);
    main_mod->add_return({squeezed});

    auto parsed = parse_tf("argmin_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
127
128
129
TEST_CASE(assert_less_equal_test)
{
    // Expected program: an add of the two inputs, plus the assertion, which shows
    // up as identity instructions with multiple inputs (the second consumes the
    // first together with an int32 {0, 1} literal).
    const migraphx::shape input_shape{migraphx::shape::float_type, {2, 3}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto x         = main_mod->add_parameter("0", input_shape);
    auto y         = main_mod->add_parameter("1", input_shape);
    auto int_lit   = main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}});
    main_mod->add_instruction(migraphx::make_op("add"), x, y);
    auto first_identity = main_mod->add_instruction(migraphx::make_op("identity"), x, y);
    main_mod->add_instruction(migraphx::make_op("identity"), first_identity, int_lit);

    auto parsed = optimize_tf("assert_less_equal_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
145
146
147
148
TEST_CASE(batchmatmul_test)
{
    // Expected program: both batched inputs have their last two axes swapped
    // before the dot.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto a =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}});
    auto b =
        main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});

    const auto swap_last_two = migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}});
    auto a_t                 = main_mod->add_instruction(swap_last_two, a);
    auto b_t                 = main_mod->add_instruction(swap_last_two, b);
    main_mod->add_instruction(migraphx::make_op("dot"), a_t, b_t);

    auto parsed = optimize_tf("batchmatmul_test.pb", false);
    EXPECT(expected == parsed);
}

164
165
TEST_CASE(batchnorm_test)
{
    // Expected program: a spatial batch_norm_inference whose second input is a
    // {32} constant of ones and whose remaining per-channel inputs are the graph
    // parameters "2".."4". The parameters are declared before the literal to match
    // the parser's instruction order.
    const float epsilon  = 1.001e-5f;
    const float momentum = 0.9f;
    const migraphx::op::batch_norm_inference bn_op{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
    const migraphx::shape channel_shape{migraphx::shape::float_type, {32}};
    const std::vector<float> ones(32, 1.0f);

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
    auto p2       = main_mod->add_parameter("2", channel_shape);
    auto p3       = main_mod->add_parameter("3", channel_shape);
    auto p4       = main_mod->add_parameter("4", channel_shape);
    auto ones_lit = main_mod->add_literal(migraphx::literal{channel_shape, ones});
    main_mod->add_instruction(bn_op, input, ones_lit, p2, p3, p4);

    auto parsed = optimize_tf("batchnorm_test.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
189
190
191
192
193
194
TEST_CASE(batchnormv3_test)
{
    // Expected program: same structure as batchnorm_test, with epsilon 1e-5 — a
    // spatial batch_norm_inference over a {32} constant of ones and the graph
    // parameters "2".."4" (declared before the literal, as the parser emits them).
    const float epsilon  = 1.0e-5f;
    const float momentum = 0.9f;
    const migraphx::op::batch_norm_inference bn_op{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
    const migraphx::shape channel_shape{migraphx::shape::float_type, {32}};
    const std::vector<float> ones(32, 1.0f);

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
    auto p2       = main_mod->add_parameter("2", channel_shape);
    auto p3       = main_mod->add_parameter("3", channel_shape);
    auto p4       = main_mod->add_parameter("4", channel_shape);
    auto ones_lit = main_mod->add_literal(migraphx::literal{channel_shape, ones});
    main_mod->add_instruction(bn_op, input, ones_lit, p2, p3, p4);

    auto parsed = optimize_tf("batchnormv3_test.pb", true);
    EXPECT(expected == parsed);
}

213
214
215
TEST_CASE(biasadd_test)
{
    // Expected program: the length-500 bias is broadcast along axis 1 of the
    // {1, 500, 1, 1} input and then added to it.
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 500, 1, 1}};
    const uint64_t bias_axis = 1;

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter("0", input_shape);
    auto bias = main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
    auto bias_b = main_mod->add_instruction(
        migraphx::make_op("broadcast", {{"axis", bias_axis}, {"dims", input_shape.lens()}}), bias);
    main_mod->add_instruction(migraphx::make_op("add"), input, bias_b);

    auto parsed = optimize_tf("biasadd_test.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
230
231
232
233
234
235
236
237
238
239
TEST_CASE(biasadd_scalar_test)
{
    // Expected program: a bias constant of shape {1} with stride 0 is broadcast
    // along axis 1 of the {1, 1} input and added to it.
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 1}};
    const migraphx::shape bias_shape{migraphx::shape::float_type, {1}, {0}};
    const uint64_t bias_axis = 1;

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter("0", input_shape);
    auto bias      = main_mod->add_literal(migraphx::literal{bias_shape, {1.0}});
    auto bias_b    = main_mod->add_instruction(
        migraphx::make_op("broadcast", {{"axis", bias_axis}, {"dims", input_shape.lens()}}), bias);
    main_mod->add_instruction(migraphx::make_op("add"), input, bias_b);

    auto parsed = optimize_tf("biasadd_scalar_test.pb", true);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
248
249
250
TEST_CASE(cast_test)
{
    // Expected program: the float input is converted to int32.
    const auto to_int32 = migraphx::make_op(
        "convert", {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}});

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_instruction(to_int32, input);

    auto parsed = optimize_tf("cast_test.pb", false);
    EXPECT(expected == parsed);
}

263
264
265
TEST_CASE(concat_test)
{
    // Expected program: the two inputs concatenated along axis 1.
    const int concat_axis = 1;

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto first =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}});
    auto second =
        main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});

    // TF supplies the axis as the third input, in int32 format. It is added via a
    // vector so the literal gets a stride of 1, matching the tf parser.
    main_mod->add_literal(migraphx::shape{migraphx::shape::int32_type},
                          std::vector<int>{concat_axis});

    main_mod->add_instruction(migraphx::make_op("concat", {{"axis", concat_axis}}), first, second);

    auto parsed = optimize_tf("concat_test.pb", false);
    EXPECT(expected == parsed);
}

TEST_CASE(const_test)
{
    // Expected program: a single float scalar literal holding 1.0.
    migraphx::program expected;
    expected.get_main_module()->add_literal(migraphx::shape{migraphx::shape::float_type},
                                            std::vector<float>{1.0f});

    auto parsed = optimize_tf("constant_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
294
// Builds the program the conv_* TF graphs are expected to produce: a
// {1, 3, 16, 16} input convolved with a constant {3, 3, 3, 32} filter of ones.
// The filter is transposed with dims {3, 2, 0, 1} before the convolution
// (presumably HWIO -> OIHW). The convolution uses "same" padding mode, padding
// {1, 1, 1, 1}, unit stride and unit dilation. No return instruction is added,
// so the convolution is the module's last instruction.
migraphx::program create_conv()
{
    const std::vector<float> filter_vals(3 * 3 * 3 * 32, 1.0f);

    migraphx::op::convolution conv;
    conv.padding_mode = migraphx::op::padding_mode_t::same;
    conv.padding      = {1, 1, 1, 1};
    conv.stride       = {1, 1};
    conv.dilation     = {1, 1};

    migraphx::program result;
    auto* main_mod = result.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto filter = main_mod->add_literal(
        migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, filter_vals);
    auto filter_t =
        main_mod->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), filter);
    main_mod->add_instruction(conv, input, filter_t);
    return result;
}

TEST_CASE(conv_test)
{
    // The NHWC convolution graph must reduce to the canonical expected conv program.
    const migraphx::program expected = create_conv();
    const migraphx::program parsed   = optimize_tf("conv_test.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
324
325
326
327
328
329
330
331
332
333
334
TEST_CASE(conv_add_test)
{
    // Expected program: the conv program with its output added to itself. This
    // relies on create_conv() leaving the convolution as the last instruction.
    migraphx::program expected = create_conv();
    auto* main_mod             = expected.get_main_module();
    const auto conv_out        = std::prev(main_mod->end());
    main_mod->add_instruction(migraphx::make_op("add"), conv_out, conv_out);

    auto parsed = optimize_tf("conv_add_test.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
335
336
337
338
TEST_CASE(conv_nchw_test)
{
    // The NCHW variant of the convolution graph (no NHWC passes) must also match
    // the canonical expected conv program.
    const migraphx::program expected = create_conv();
    const migraphx::program parsed   = optimize_tf("conv_nchw_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
TEST_CASE(conv_relu_test)
{
    // Expected program: the conv program followed by a relu on its output. This
    // relies on create_conv() leaving the convolution as the last instruction.
    migraphx::program expected = create_conv();
    auto* main_mod             = expected.get_main_module();
    const auto conv_out        = std::prev(main_mod->end());
    main_mod->add_instruction(migraphx::make_op("relu"), conv_out);

    auto parsed = optimize_tf("conv_relu_test.pb", true);
    EXPECT(expected == parsed);
}

TEST_CASE(conv_relu6_test)
{
    // Expected program: the conv program followed by relu6, lowered to a clip
    // between 0 and 6 constants multibroadcast to the convolution's output shape.
    migraphx::program p = create_conv();
    auto* mm            = p.get_main_module();
    auto l0             = std::prev(mm->end());
    // Take the broadcast target from the convolution instruction itself instead
    // of a hand-copied {1, 32, 16, 16}, so this stays in sync with create_conv().
    const auto input_lens = l0->get_shape().lens();
    auto min_val          = mm->add_literal(0.0f);
    auto max_val          = mm->add_literal(6.0f);
    min_val               = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
    max_val = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
    mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
    auto prog = optimize_tf("conv_relu6_test.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
372
373
374
375
TEST_CASE(depthwiseconv_test)
{
    // Expected program: a grouped (group = 3) convolution. The {3, 3, 3, 1} filter
    // of ones is transposed with dims {3, 2, 0, 1}, made contiguous and reshaped to
    // {3, 1, 3, 3} before being used as the convolution weights.
    const std::vector<float> filter_vals(3 * 3 * 3 * 1, 1.0f);

    migraphx::op::convolution conv;
    conv.padding_mode = migraphx::op::padding_mode_t::same;
    conv.padding      = {1, 1};
    conv.stride       = {1, 1};
    conv.dilation     = {1, 1};
    conv.group        = 3;

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto filter = main_mod->add_literal(
        migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, filter_vals);
    auto transposed =
        main_mod->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), filter);
    auto packed = main_mod->add_instruction(migraphx::make_op("contiguous"), transposed);
    auto kernel =
        main_mod->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), packed);
    main_mod->add_instruction(conv, input, kernel);

    auto parsed = optimize_tf("depthwise_conv_test.pb", true);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
399
400
401
TEST_CASE(expanddims_test)
{
    // Expected program: inserting a leading dimension becomes a reshape to
    // {1, 2, 3, 4}. The dim operand (0) remains as a scalar literal.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    main_mod->add_literal(0);
    main_mod->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4}}}), input);

    auto parsed = optimize_tf("expanddims_test.pb", false);
    EXPECT(expected == parsed);
}

TEST_CASE(expanddims_test_neg_dims)
{
    // Checks that a negative dim value in the .pb is parsed correctly: dim -1
    // appends a trailing dimension, giving a reshape to {2, 3, 4, 1}. The dim
    // operand (-1) remains as a scalar literal.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    main_mod->add_literal(-1);
    main_mod->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), input);

    auto parsed = optimize_tf("expanddims_neg_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
428
429
430
431
TEST_CASE(gather_test)
{
    // Expected program: a gather of the constant indices {1, 1} along axis 1 of
    // the {2, 4} input. The scalar literal 1 (presumably the graph's axis
    // operand) is also present.
    const int gather_axis = 1;

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto data = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    auto indices = main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}});
    main_mod->add_literal(1);
    main_mod->add_instruction(migraphx::make_op("gather", {{"axis", gather_axis}}), data, indices);

    auto parsed = optimize_tf("gather_test.pb", false);
    EXPECT(expected == parsed);
}

446
447
448
TEST_CASE(identity_test)
{
    // Expected program: a single identity applied to the input.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_instruction(migraphx::make_op("identity"), input);

    auto parsed = optimize_tf("identity_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
458
459
460
TEST_CASE(matmul_test)
{
    // Expected program: both 2-D inputs are transposed before the dot.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto a = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
    auto b = main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});

    const auto swap_axes = migraphx::make_op("transpose", {{"dims", {1, 0}}});
    auto a_t             = main_mod->add_instruction(swap_axes, a);
    auto b_t             = main_mod->add_instruction(swap_axes, b);
    main_mod->add_instruction(migraphx::make_op("dot"), a_t, b_t);

    auto parsed = optimize_tf("matmul_test.pb", false);
    EXPECT(expected == parsed);
}

475
476
477
TEST_CASE(mean_test)
{
    // Expected program: two reduce_mean instructions over axes {2, 3} of the same
    // input, where only the second result is squeezed (presumably the graph holds
    // one keep_dims=true and one keep_dims=false Mean — TODO confirm). An int32
    // {2, 3} axes literal is present twice, once per Mean node.
    const migraphx::literal axes_lit{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}};
    const migraphx::op::reduce_mean mean_op{{2, 3}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_literal(axes_lit);
    main_mod->add_literal(axes_lit);
    main_mod->add_instruction(mean_op, input);
    auto reduced = main_mod->add_instruction(mean_op, input);
    main_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), reduced);

    auto parsed = optimize_tf("mean_test.pb", false);
    EXPECT(expected == parsed);
}

TEST_CASE(mean_test_nhwc)
{
    // Expected program after the NHWC passes run by optimize_tf: the input is
    // transposed with dims {0, 2, 3, 1}, reduced over axes {1, 2} and squeezed.
    // No literal for the graph's int32 axes operand is expected: presumably it has
    // no users after parsing, so the dead_code_elimination pass removes it.
    // (The previous version built such a literal but never added it to the
    // program; that unused local has been removed.)
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto l1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
    migraphx::op::reduce_mean op{{1, 2}};
    auto l2 = mm->add_instruction(op, l1);
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2);
    auto prog = optimize_tf("mean_test_nhwc.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
509
510
511
512
TEST_CASE(mul_test)
{
    // Expected program: elementwise multiply of two float {1, 1, 1, 16} inputs.
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 1, 1, 16}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto lhs       = main_mod->add_parameter("0", input_shape);
    auto rhs       = main_mod->add_parameter("1", input_shape);
    main_mod->add_instruction(migraphx::make_op("mul"), lhs, rhs);

    auto parsed = optimize_tf("mul_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
TEST_CASE(multi_output_test)
{
    // Expected program: relu and tanh of the same input, both returned as outputs.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto relu_out = main_mod->add_instruction(migraphx::make_op("relu"), input);
    auto tanh_out = main_mod->add_instruction(migraphx::make_op("tanh"), input);
    main_mod->add_return({relu_out, tanh_out});

    // Requesting "relu6" as an output must make parsing throw (presumably because
    // the graph has no node of that name).
    EXPECT(test::throws([&] { parse_tf("multi_output_test.pb", false, {}, {"relu", "relu6"}); }));

    auto parsed = parse_tf("multi_output_test.pb", false, {}, {"relu", "tanh"});
    EXPECT(expected == parsed);
}

539
540
541
TEST_CASE(onehot_test)
{
    // Expected program: rows are gathered along axis 0 from a constant 2x2
    // identity matrix, using five indices of 1. The scalar literals 2, 1.0f and
    // 0.0f (presumably the depth, on-value and off-value operands) are also present.
    const int gather_axis = 0;

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto indices   = main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
    main_mod->add_literal(2);
    main_mod->add_literal(1.0f);
    main_mod->add_literal(0.0f);
    auto eye = main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});
    main_mod->add_instruction(migraphx::make_op("gather", {{"axis", gather_axis}}), eye, indices);

    auto parsed = optimize_tf("onehot_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
558
559
560
561
562
563
564
565
TEST_CASE(noop_test)
{
    // The noop graph must parse to an empty program once optimize_tf has removed
    // any trailing @return.
    migraphx::program empty_program;
    auto parsed = optimize_tf("noop_test.pb", false);
    EXPECT(empty_program == parsed);
}

Khalique's avatar
Khalique committed
566
567
568
TEST_CASE(pack_test)
{
    // Expected program: three {2} inputs are each unsqueezed at axis 1 and then
    // concatenated along that axis. All parameters are declared before any
    // unsqueeze, matching the parser's instruction order.
    const int64_t pack_axis = 1;
    const migraphx::shape input_shape{migraphx::shape::float_type, {2}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    std::vector<migraphx::instruction_ref> inputs;
    for(const char* param_name : {"0", "1", "2"})
        inputs.push_back(main_mod->add_parameter(param_name, input_shape));

    std::vector<migraphx::instruction_ref> unsqueezed;
    unsqueezed.reserve(inputs.size());
    for(auto in : inputs)
        unsqueezed.push_back(main_mod->add_instruction(
            migraphx::make_op("unsqueeze", {{"axes", {pack_axis}}}), in));

    main_mod->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(pack_axis)}}),
                              unsqueezed);

    auto parsed = optimize_tf("pack_test.pb", false);
    EXPECT(expected == parsed);
}

592
593
594
TEST_CASE(pack_test_nhwc)
{
    // Expected program after the NHWC passes: each {1, 2, 1, 1} input is
    // transposed with dims {0, 2, 3, 1} right after it is declared. The three
    // results are each unsqueezed at axis 3 and then concatenated along axis 3.
    const int64_t nchw_axis = 3;
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 1, 1}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    std::vector<migraphx::instruction_ref> transposed;
    for(const char* param_name : {"0", "1", "2"})
    {
        auto param = main_mod->add_parameter(param_name, input_shape);
        transposed.push_back(main_mod->add_instruction(
            migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), param));
    }

    std::vector<migraphx::instruction_ref> unsqueezed;
    unsqueezed.reserve(transposed.size());
    for(auto t : transposed)
        unsqueezed.push_back(main_mod->add_instruction(
            migraphx::make_op("unsqueeze", {{"axes", {nchw_axis}}}), t));

    main_mod->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(nchw_axis)}}),
                              unsqueezed);

    auto parsed = optimize_tf("pack_test_nhwc.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
TEST_CASE(pad_test)
{
    // Expected program: a pad op with pads {1, 2, 1, 2}. The graph's int32 {2, 2}
    // paddings operand {1, 1, 2, 2} remains as a literal. The reordering presumably
    // turns per-dimension (before, after) pairs into all-befores then all-afters.
    const std::vector<int> tf_paddings{1, 1, 2, 2};
    const std::vector<int> pads{1, 2, 1, 2};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    main_mod->add_literal(migraphx::shape{migraphx::shape::int32_type, {2, 2}}, tf_paddings);
    main_mod->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);

    auto parsed = optimize_tf("pad_test.pb", false);
    EXPECT(expected == parsed);
}

638
639
640
TEST_CASE(pooling_test)
{
    // Expected program: average pooling then max pooling of the same input, both
    // with a 2x2 window and stride 2.
    const auto make_pool = [](const char* mode) {
        migraphx::op::pooling pool{mode};
        pool.stride  = {2, 2};
        pool.lengths = {2, 2};
        return pool;
    };

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_instruction(make_pool("average"), input);
    main_mod->add_instruction(make_pool("max"), input);

    auto parsed = optimize_tf("pooling_test.pb", true);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
657
658
659
TEST_CASE(pow_test)
{
    // Expected program: elementwise pow of two float {1, 2, 2, 3} inputs.
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 2, 3}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto base      = main_mod->add_parameter("0", input_shape);
    auto exponent  = main_mod->add_parameter("1", input_shape);
    main_mod->add_instruction(migraphx::make_op("pow"), base, exponent);

    auto parsed = optimize_tf("pow_test.pb", false);
    EXPECT(expected == parsed);
}

670
671
672
TEST_CASE(relu_test)
{
    // Expected program: a single relu applied to the input.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_instruction(migraphx::make_op("relu"), input);

    auto parsed = optimize_tf("relu_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
682
683
684
TEST_CASE(relu6_test)
{
    // Expected program: relu6 lowered to clip(x, 0, 6), with both bounds
    // multibroadcast to the input shape.
    const std::vector<size_t> input_lens{1, 3, 16, 16};
    const auto broadcast_to_input =
        migraphx::make_op("multibroadcast", {{"output_lens", input_lens}});

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
    auto lower   = main_mod->add_literal(0.0f);
    auto upper   = main_mod->add_literal(6.0f);
    auto lower_b = main_mod->add_instruction(broadcast_to_input, lower);
    auto upper_b = main_mod->add_instruction(broadcast_to_input, upper);
    main_mod->add_instruction(migraphx::make_op("clip"), input, lower_b, upper_b);

    auto parsed = optimize_tf("relu6_test.pb", false);
    EXPECT(expected == parsed);
}

701
702
703
TEST_CASE(reshape_test)
{
    // Expected program: the {16} input reshaped to {1, 1, 1, 16}. TF passes the
    // new dimensions as a second, constant int32 input, which remains as a literal.
    const migraphx::shape dims_shape{migraphx::shape::int32_type, {4}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});
    main_mod->add_literal(migraphx::literal{dims_shape, {1, 1, 1, 16}});
    main_mod->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 1, 16}}}), input);

    auto parsed = optimize_tf("reshape_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
716
717
718
TEST_CASE(rsqrt_test)
{
    // Expected program: a single rsqrt applied to the input.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input     = main_mod->add_parameter(
        "0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_instruction(migraphx::make_op("rsqrt"), input);

    auto parsed = optimize_tf("rsqrt_test.pb", false);
    EXPECT(expected == parsed);
}

728
729
730
TEST_CASE(shape_test)
{
    // Expected program: the input parameter (otherwise unused) plus an int32 {4}
    // literal holding its dimensions {1, 3, 16, 16}.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 16, 16}});

    auto parsed = optimize_tf("shape_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
741
742
743
TEST_CASE(slice_test)
{
    // Expected program: a slice of the {5, 10} input over axes {0, 1} with starts
    // {1, 0} and ends {3, 10}. The graph's int32 operands {1, 0} and {2, -1}
    // (presumably begin and size; begin + size gives the ends, -1 meaning "to the
    // end") remain as literals.
    const std::size_t rank = 2;
    const migraphx::shape operand_shape{migraphx::shape::int32_type, {rank}};

    migraphx::op::slice slice_op;
    slice_op.starts = {1, 0};
    slice_op.ends   = {3, 10};
    slice_op.axes   = {0, 1};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}});
    main_mod->add_literal(migraphx::literal{operand_shape, {1, 0}});
    main_mod->add_literal(migraphx::literal{operand_shape, {2, -1}});
    main_mod->add_instruction(slice_op, input);

    auto parsed = optimize_tf("slice_test.pb", false);
    EXPECT(expected == parsed);
}

763
764
765
TEST_CASE(softmax_test)
{
    migraphx::program p;
766
767
768

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
769
    mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), l0);
Paul's avatar
Paul committed
770
    auto prog = optimize_tf("softmax_test.pb", false);
771
772
773
774

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
775
776
777
TEST_CASE(split_test)
{
    migraphx::program p;
778
779

    auto* mm = p.get_main_module();
kahmed10's avatar
kahmed10 committed
780
    std::vector<int64_t> axes{0, 1};
781
782
783
784
785
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    mm->add_literal(3); // num_splits
    mm->add_literal(1); // split axis
    mm->add_literal(1); // concat axis
    mm->add_literal(1); // concat axis
786
787
788
789
790
791
    auto l1 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 10}}}), l0);
    auto l2 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 10}}, {"ends", {5, 20}}}), l0);
    auto l3 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0);
kahmed10's avatar
kahmed10 committed
792
793
794
    auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
    auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
    mm->add_return({l4, l5});
795
    auto prog = parse_tf("split_test.pb", false);
kahmed10's avatar
kahmed10 committed
796
797
798
799
800
801
802

    EXPECT(p == prog);
}

// split_test_one_output.pb: a split with num_splits == 1 reduces to an identity
// on the {5, 30} input. Uses parse_tf, so the expected program ends with a return.
TEST_CASE(split_test_one_output)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    mm->add_literal(1); // num_splits
    mm->add_literal(1); // split axis

    auto passthrough = mm->add_instruction(migraphx::make_op("identity"), input);
    mm->add_return({passthrough});

    auto prog = parse_tf("split_test_one_output.pb", false);

    EXPECT(p == prog);
}

// split_test_vector_as_input.pb: a {5, 30} input split along axis 1 using explicit
// split sizes {4, 15, 11}, giving column ranges [0, 4), [4, 19) and [19, 30).
// Neighbouring pieces are concatenated pairwise and both results are returned.
TEST_CASE(split_test_vector_as_input)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    std::vector<int64_t> axes{0, 1};
    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});

    // split sizes
    migraphx::shape sizes_shape{migraphx::shape::int32_type, {3}};
    mm->add_literal(migraphx::literal{sizes_shape, {4, 15, 11}});
    mm->add_literal(1); // split axis
    mm->add_literal(1); // axis of the first concat
    mm->add_literal(1); // axis of the second concat

    auto first = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 4}}}), input);
    auto second = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 4}}, {"ends", {5, 19}}}), input);
    auto third = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}),
        input);

    auto cat_a = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), first, second);
    auto cat_b = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), second, third);
    mm->add_return({cat_a, cat_b});

    auto prog = parse_tf("split_test_vector_as_input.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
842
843
844
TEST_CASE(sqdiff_test)
{
    migraphx::program p;
845
846
847
848

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
849
    mm->add_instruction(migraphx::make_op("sqdiff"), l0, l1);
Khalique's avatar
Khalique committed
850
851
852
853
854
    auto prog = optimize_tf("sqdiff_test.pb", false);

    EXPECT(p == prog);
}

855
856
857
TEST_CASE(squeeze_test)
{
    migraphx::program p;
858
859
860

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
861
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 3}}}), l0);
Paul's avatar
Paul committed
862
    auto prog = optimize_tf("squeeze_test.pb", false);
863
864
865

    EXPECT(p == prog);
}
Khalique's avatar
Khalique committed
866

Khalique's avatar
Khalique committed
867
868
869
TEST_CASE(stopgradient_test)
{
    migraphx::program p;
870
871
872

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
873
    mm->add_instruction(migraphx::make_op("identity"), l0);
Khalique's avatar
Khalique committed
874
875
    auto prog = optimize_tf("stopgradient_test.pb", false);

Khalique's avatar
Khalique committed
876
    EXPECT(p == prog);
Khalique's avatar
Khalique committed
877
878
}

Khalique's avatar
Khalique committed
879
880
881
TEST_CASE(stridedslice_test)
{
    migraphx::program p;
882
883
884

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
885
    auto l1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
Khalique's avatar
Khalique committed
886
887
    std::size_t num_axes = 4;
    migraphx::op::slice op;
Khalique's avatar
Khalique committed
888
    op.starts = {0, 0, 0, 0};
Paul's avatar
Paul committed
889
    op.ends   = {1, 1, 1, 5};
Khalique's avatar
Khalique committed
890
891
    op.axes   = std::vector<int64_t>(num_axes);
    std::iota(op.axes.begin(), op.axes.end(), 0);
892
    auto l2          = mm->add_instruction(op, l1);
Paul's avatar
Paul committed
893
    auto shrink_axis = 1;
894
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);
Paul's avatar
Paul committed
895
    auto prog = optimize_tf("stridedslice_test.pb", true);
Khalique's avatar
Khalique committed
896
897
898
899

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
900
901
902
TEST_CASE(stridedslice_masks_test)
{
    migraphx::program p;
903
904
905

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
Khalique's avatar
Khalique committed
906
907
    std::size_t num_axes = 4;
    migraphx::op::slice op;
908
909
    op.starts = {0, 1, 1, 0};
    op.ends   = {1, 3, 3, 10};
Khalique's avatar
Khalique committed
910
911
912
    op.axes   = std::vector<int64_t>(num_axes);
    std::iota(op.axes.begin(), op.axes.end(), 0);
    // add literals for starts, ends, and strides in tf (NHWC format)
913
914
915
916
917
918
919
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 1, 1, 0});
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 0, 0, 0});
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{1, 1, 1, 1});

920
    auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
921
    auto l2 = mm->add_instruction(op, l1);
kahmed10's avatar
kahmed10 committed
922
923
    auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2);
    mm->add_return({l3});
924
    auto prog = parse_tf("stridedslice_masks_test.pb", true);
Khalique's avatar
Khalique committed
925
926
927
928

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
929
930
931
TEST_CASE(sub_test)
{
    migraphx::program p;
932
933
934
935

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
kahmed10's avatar
kahmed10 committed
936
937
    auto l2  = mm->add_instruction(migraphx::make_op("sub"), l0, l1);
    mm->add_return({l2});
938
    auto prog = parse_tf("sub_test.pb", false);
Khalique's avatar
Khalique committed
939
940
941
942

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
943
944
945
TEST_CASE(tanh_test)
{
    migraphx::program p;
946
947

    auto* mm = p.get_main_module();
kahmed10's avatar
kahmed10 committed
948
949
950
951
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto l1  = mm->add_instruction(migraphx::make_op("tanh"), l0);
    mm->add_return({l1});
    auto prog = parse_tf("tanh_test.pb", false);
Khalique's avatar
Khalique committed
952
953
954
955

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
956
957
958
TEST_CASE(transpose_test)
{
    migraphx::program p;
959
960
961

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
Khalique's avatar
Khalique committed
962
    migraphx::shape s0{migraphx::shape::int32_type, {4}};
963
    mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}});
964
    mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
Khalique's avatar
Khalique committed
965
966
967
968
969
    auto prog = optimize_tf("transpose_test.pb", false);

    EXPECT(p == prog);
}

970
971
972
TEST_CASE(variable_batch_test)
{
    migraphx::program p;
973
974
975

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
976
    mm->add_instruction(migraphx::make_op("identity"), l0);
977
978
979
980
981
    auto prog = optimize_tf("variable_batch_test.pb", false);

    EXPECT(p == prog);
}

982
int main(int argc, const char* argv[]) { test::run(argc, argv); }