tf_test.cpp 32.2 KB
Newer Older
1
2
#include <iostream>
#include <vector>
Shucai Xiao's avatar
Shucai Xiao committed
3
#include <unordered_map>
4
#include <migraphx/literal.hpp>
Paul's avatar
Paul committed
5
6
7
8
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
9
10
11
12
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp>
13
14
15
16
#include <migraphx/make_op.hpp>

#include <migraphx/serialize.hpp>

17
18
#include "test.hpp"

Shucai Xiao's avatar
Shucai Xiao committed
19
20
21
// Parse a TensorFlow protobuf without any post-processing. The result keeps
// whatever "@return" instruction the parser emits, so callers compare against
// reference programs that add their own return.
//   name              - path of the .pb file to load
//   is_nhwc           - forwarded as the first tf_options field
//   dim_params        - input dimension overrides (presumably keyed by input
//                       name, e.g. {"0", {4, 5, 6, 7}}); empty = use the graph's
//   output_node_names - graph nodes to expose as outputs; empty = default
migraphx::program
parse_tf(const std::string& name,
         bool is_nhwc,
         const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {},
         const std::vector<std::string>& output_node_names                           = {})
{
    migraphx::tf_options options{is_nhwc, 1, dim_params, output_node_names};
    return migraphx::parse_tf(name, options);
}

Paul's avatar
Paul committed
29
30
// Parse a TensorFlow protobuf and normalise it so it can be compared with a
// hand-built reference program.
//   name    - path of the .pb file to load
//   is_nhwc - forwarded to tf_options; when true, the main module is also run
//             through simplify_reshapes, dead_code_elimination and
//             eliminate_identity to clean up layout-conversion artefacts
// Returns the parsed program with a trailing "@return" instruction (if any)
// removed, because the reference programs built with optimize_tf add none.
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
    auto prog = migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1});
    auto* mm  = prog.get_main_module();
    if(is_nhwc)
        migraphx::run_passes(*mm,
                             {migraphx::simplify_reshapes{},
                              migraphx::dead_code_elimination{},
                              migraphx::eliminate_identity{}});

    // Remove the trailing return instruction, if there is one. Emptiness must
    // be checked before calling std::prev: decrementing end() of an empty
    // module is undefined behaviour. std::prev(end()) can never compare equal
    // to end(), so it is not a usable guard.
    if(mm->begin() != mm->end())
    {
        auto last_ins = std::prev(mm->end());
        if(last_ins->name() == "@return")
        {
            mm->remove_instruction(last_ins);
        }
    }
    return prog;
}

51
52
53
TEST_CASE(add_test)
{
    // Expected graph: element-wise add of two identically shaped float inputs.
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 2, 3}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto lhs       = main_mod->add_parameter("0", in_shape);
    auto rhs       = main_mod->add_parameter("1", in_shape);
    main_mod->add_instruction(migraphx::make_op("add"), lhs, rhs);

    auto parsed = optimize_tf("add_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
64
65
66
TEST_CASE(addv2_test)
{
    // AddV2 must lower to the same plain element-wise add as Add.
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 2, 3}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto lhs       = main_mod->add_parameter("0", in_shape);
    auto rhs       = main_mod->add_parameter("1", in_shape);
    main_mod->add_instruction(migraphx::make_op("add"), lhs, rhs);

    auto parsed = optimize_tf("addv2_test.pb", false);
    EXPECT(expected == parsed);
}

76
77
TEST_CASE(add_bcast_test)
{
    // The {2, 1} operand is multibroadcast to the {2, 3} shape of the other
    // operand before the add.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    migraphx::shape full_shape{migraphx::shape::float_type, {2, 3}};
    auto lhs = main_mod->add_parameter("0", full_shape);
    auto rhs = main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
    auto rhs_bcast = main_mod->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", full_shape.lens()}}), rhs);
    main_mod->add_instruction(migraphx::make_op("add"), lhs, rhs_bcast);

    auto parsed = optimize_tf("add_bcast_test.pb", false);
    EXPECT(expected == parsed);
}

93
94
95
TEST_CASE(argmax_test)
{
    // argmax over axis 2 followed by a squeeze of that axis. parse_tf keeps the
    // parser's "@return", so the reference returns the squeezed result too.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}});
    // The axis operand of the TF node appears as an int32 scalar literal.
    main_mod->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
    auto arg_idx = main_mod->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), input);
    auto squeezed =
        main_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), arg_idx);
    main_mod->add_return({squeezed});

    auto parsed = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}});
    EXPECT(expected == parsed);
}

TEST_CASE(argmin_test)
{
    // argmin over axis 2 followed by a squeeze of that axis; the unoptimised
    // parse keeps its "@return", so the reference adds one as well.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
    // The axis operand of the TF node appears as an int32 scalar literal.
    main_mod->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
    auto arg_idx = main_mod->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), input);
    auto squeezed =
        main_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), arg_idx);
    main_mod->add_return({squeezed});

    auto parsed = parse_tf("argmin_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
123
124
125
TEST_CASE(assert_less_equal_test)
{
    // Expected graph: the add of the two inputs plus two identity instructions
    // that consume the inputs and an int32 {0, 1} literal (this is how the
    // assert_less_equal node shows up in the parsed program).
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    migraphx::shape in_shape{migraphx::shape::float_type, {2, 3}};
    auto x = main_mod->add_parameter("0", in_shape);
    auto y = main_mod->add_parameter("1", in_shape);

    migraphx::literal const_lit{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}};
    auto const_ins = main_mod->add_literal(const_lit);

    main_mod->add_instruction(migraphx::make_op("add"), x, y);
    auto first_id = main_mod->add_instruction(migraphx::make_op("identity"), x, y);
    main_mod->add_instruction(migraphx::make_op("identity"), first_id, const_ins);

    auto parsed = optimize_tf("assert_less_equal_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
141
142
143
144
TEST_CASE(batchmatmul_test)
{
    // Both operands have their last two dimensions swapped before the dot
    // (the .pb presumably sets adj_x/adj_y).
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto a =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}});
    auto b =
        main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});

    const auto swap_last_two = migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}});
    auto a_t                 = main_mod->add_instruction(swap_last_two, a);
    auto b_t                 = main_mod->add_instruction(swap_last_two, b);
    main_mod->add_instruction(migraphx::make_op("dot"), a_t, b_t);

    auto parsed = optimize_tf("batchmatmul_test.pb", false);
    EXPECT(expected == parsed);
}

160
161
TEST_CASE(batchnorm_test)
{
    // NHWC batch norm -> spatial batch_norm_inference. The second operand is a
    // constant of 32 ones and appears as a literal; the others are parameters.
    const float epsilon  = 1.001e-5f;
    const float momentum = 0.9f;

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    migraphx::op::batch_norm_inference bn_op{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
    migraphx::shape channel_shape{migraphx::shape::float_type, {32}};

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
    std::vector<float> ones(32, 1.0f);

    auto in2       = main_mod->add_parameter("2", channel_shape);
    auto in3       = main_mod->add_parameter("3", channel_shape);
    auto in4       = main_mod->add_parameter("4", channel_shape);
    auto const_in1 = main_mod->add_literal(migraphx::literal{channel_shape, ones});
    main_mod->add_instruction(bn_op, input, const_in1, in2, in3, in4);

    auto parsed = optimize_tf("batchnorm_test.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
185
186
187
188
189
190
TEST_CASE(batchnormv3_test)
{
    // FusedBatchNormV3 variant: same lowering as batchnorm_test but with
    // epsilon 1e-5. The second operand is a literal of 32 ones.
    const float epsilon  = 1.0e-5f;
    const float momentum = 0.9f;

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    migraphx::op::batch_norm_inference bn_op{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
    migraphx::shape channel_shape{migraphx::shape::float_type, {32}};

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
    std::vector<float> ones(32, 1.0f);

    auto in2       = main_mod->add_parameter("2", channel_shape);
    auto in3       = main_mod->add_parameter("3", channel_shape);
    auto in4       = main_mod->add_parameter("4", channel_shape);
    auto const_in1 = main_mod->add_literal(migraphx::literal{channel_shape, ones});
    main_mod->add_instruction(bn_op, input, const_in1, in2, in3, in4);

    auto parsed = optimize_tf("batchnormv3_test.pb", true);
    EXPECT(expected == parsed);
}

209
210
211
TEST_CASE(biasadd_test)
{
    // BiasAdd: the {500} bias is broadcast along axis 1 of the input and added.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    migraphx::shape input_shape{migraphx::shape::float_type, {1, 500, 1, 1}};
    uint64_t bcast_axis = 1;

    auto input = main_mod->add_parameter("0", input_shape);
    auto bias  = main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
    auto bias_bcast = main_mod->add_instruction(
        migraphx::make_op("broadcast", {{"axis", bcast_axis}, {"dims", input_shape.lens()}}),
        bias);
    main_mod->add_instruction(migraphx::make_op("add"), input, bias_bcast);

    auto parsed = optimize_tf("biasadd_test.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
226
227
228
229
230
231
232
233
234
235
TEST_CASE(biasadd_scalar_test)
{
    // BiasAdd with a constant scalar-like bias: a stride-0 {1} literal holding
    // 1.0 is broadcast along axis 1 and added to the input.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    migraphx::shape input_shape{migraphx::shape::float_type, {1, 1}};
    uint64_t bcast_axis = 1;

    auto input = main_mod->add_parameter("0", input_shape);
    auto bias  = main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}});
    auto bias_bcast = main_mod->add_instruction(
        migraphx::make_op("broadcast", {{"axis", bcast_axis}, {"dims", input_shape.lens()}}),
        bias);
    main_mod->add_instruction(migraphx::make_op("add"), input, bias_bcast);

    auto parsed = optimize_tf("biasadd_scalar_test.pb", true);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
244
245
246
TEST_CASE(cast_test)
{
    // Cast float -> int32 lowers to a single convert instruction.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    const auto to_int32 = migraphx::make_op(
        "convert", {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}});
    main_mod->add_instruction(to_int32, input);

    auto parsed = optimize_tf("cast_test.pb", false);
    EXPECT(expected == parsed);
}

259
260
261
TEST_CASE(concat_test)
{
    // ConcatV2 of {4, 7, 3} and {4, 2, 3} along axis 1.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto first  = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}});
    auto second = main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});

    int concat_axis = 1;
    // TF passes the axis as a separate int32 input, so it shows up as a literal.
    // Build it from a vector so its stride is 1, matching the tf parser.
    main_mod->add_literal(migraphx::shape{migraphx::shape::int32_type},
                          std::vector<int>{concat_axis});

    main_mod->add_instruction(migraphx::make_op("concat", {{"axis", concat_axis}}), first, second);

    auto parsed = optimize_tf("concat_test.pb", false);
    EXPECT(expected == parsed);
}

TEST_CASE(const_test)
{
    // A graph holding a single float constant parses to a single literal.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    const std::vector<float> value{1.0f};
    main_mod->add_literal(migraphx::shape{migraphx::shape::float_type}, value);

    auto parsed = optimize_tf("constant_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
290
// Build the reference program shared by the conv_* tests: a {1, 3, 16, 16}
// input convolved ("same" padding, stride 1, dilation 1) with a constant
// {3, 3, 3, 32} filter of ones. The filter is permuted with {3, 2, 0, 1}
// first, presumably converting TF's HWIO layout to OIHW.
migraphx::program create_conv()
{
    migraphx::program result;
    auto* main_mod = result.get_main_module();

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});

    std::vector<float> weights(3 * 3 * 3 * 32, 1.0f);
    auto filter =
        main_mod->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weights);

    migraphx::op::convolution conv;
    conv.padding_mode = migraphx::op::padding_mode_t::same;
    conv.padding      = {1, 1, 1, 1};
    conv.stride       = {1, 1};
    conv.dilation     = {1, 1};

    auto permuted_filter =
        main_mod->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), filter);
    main_mod->add_instruction(conv, input, permuted_filter);
    return result;
}

TEST_CASE(conv_test)
{
    // NHWC Conv2D must match the shared reference convolution.
    auto expected = create_conv();
    auto parsed   = optimize_tf("conv_test.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
320
321
322
323
324
325
326
327
328
329
330
TEST_CASE(conv_add_test)
{
    // The reference convolution's output is added to itself.
    auto expected  = create_conv();
    auto* main_mod = expected.get_main_module();
    auto conv_out  = std::prev(main_mod->end());
    main_mod->add_instruction(migraphx::make_op("add"), conv_out, conv_out);

    auto parsed = optimize_tf("conv_add_test.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
331
332
333
334
TEST_CASE(conv_nchw_test)
{
    // An NCHW graph (no layout passes) must yield the same reference convolution.
    auto expected = create_conv();
    auto parsed   = optimize_tf("conv_nchw_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
TEST_CASE(conv_relu_test)
{
    // The reference convolution followed by relu.
    auto expected  = create_conv();
    auto* main_mod = expected.get_main_module();
    auto conv_out  = std::prev(main_mod->end());
    main_mod->add_instruction(migraphx::make_op("relu"), conv_out);

    auto parsed = optimize_tf("conv_relu_test.pb", true);
    EXPECT(expected == parsed);
}

TEST_CASE(conv_relu6_test)
{
    // The reference convolution followed by relu6, i.e. clip to [0, 6] with
    // both bounds multibroadcast to the {1, 32, 16, 16} conv output.
    auto expected  = create_conv();
    auto* main_mod = expected.get_main_module();
    std::vector<size_t> out_lens{1, 32, 16, 16};

    auto conv_out = std::prev(main_mod->end());
    auto lo       = main_mod->add_literal(0.0f);
    auto hi       = main_mod->add_literal(6.0f);
    auto lo_bcast = main_mod->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", out_lens}}), lo);
    auto hi_bcast = main_mod->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", out_lens}}), hi);
    main_mod->add_instruction(migraphx::make_op("clip"), conv_out, lo_bcast, hi_bcast);

    auto parsed = optimize_tf("conv_relu6_test.pb", true);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
368
369
370
371
TEST_CASE(depthwiseconv_test)
{
    // DepthwiseConv2dNative -> grouped convolution with group = 3. The
    // {3, 3, 3, 1} constant filter is transposed, made contiguous, and reshaped
    // to {3, 1, 3, 3} before use.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    std::vector<float> weights(3 * 3 * 3 * 1, 1.0f);
    auto filter =
        main_mod->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weights);

    migraphx::op::convolution conv;
    conv.padding_mode = migraphx::op::padding_mode_t::same;
    conv.padding      = {1, 1};
    conv.stride       = {1, 1};
    conv.dilation     = {1, 1};
    conv.group        = 3;

    auto transposed =
        main_mod->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), filter);
    auto dense = main_mod->add_instruction(migraphx::make_op("contiguous"), transposed);
    auto grouped_filter =
        main_mod->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), dense);
    main_mod->add_instruction(conv, input, grouped_filter);

    auto parsed = optimize_tf("depthwise_conv_test.pb", true);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
395
396
397
TEST_CASE(expanddims_test)
{
    // ExpandDims at axis 0 lowers to a reshape that prepends a unit dimension.
    // The axis constant still appears as an int literal.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    main_mod->add_literal(0);
    main_mod->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4}}}), input);

    auto parsed = optimize_tf("expanddims_test.pb", false);
    EXPECT(expected == parsed);
}

TEST_CASE(expanddims_test_neg_dims)
{
    // Checks that a negative ExpandDims axis (-1) is read from the pb correctly
    // and appends a trailing unit dimension.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    main_mod->add_literal(-1);
    main_mod->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), input);

    auto parsed = optimize_tf("expanddims_neg_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
424
425
426
427
TEST_CASE(gather_test)
{
    // Gather columns {1, 1} of a {2, 4} input along axis 1. The indices and the
    // axis constant both appear as literals.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto data = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    auto indices = main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}});
    main_mod->add_literal(1);

    int gather_axis = 1;
    main_mod->add_instruction(migraphx::make_op("gather", {{"axis", gather_axis}}), data, indices);

    auto parsed = optimize_tf("gather_test.pb", false);
    EXPECT(expected == parsed);
}

442
443
444
TEST_CASE(identity_test)
{
    // Identity lowers to a single identity instruction on the input.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_instruction(migraphx::make_op("identity"), input);

    auto parsed = optimize_tf("identity_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
454
455
456
TEST_CASE(matmul_test)
{
    // Both 2-D operands are transposed before the dot (the .pb presumably sets
    // transpose_a/transpose_b).
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto a = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
    auto b = main_mod->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});

    const auto flip = migraphx::make_op("transpose", {{"dims", {1, 0}}});
    auto a_t        = main_mod->add_instruction(flip, a);
    auto b_t        = main_mod->add_instruction(flip, b);
    main_mod->add_instruction(migraphx::make_op("dot"), a_t, b_t);

    auto parsed = optimize_tf("matmul_test.pb", false);
    EXPECT(expected == parsed);
}

471
472
473
TEST_CASE(mean_test)
{
    // Expected graph: two int32 {2, 3} axes literals, two reduce_mean
    // instructions over axes {2, 3}, and a squeeze of the second result on those
    // axes (presumably the graph has one keep_dims and one non-keep_dims Mean).
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    migraphx::literal axes_lit{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}};
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_literal(axes_lit);
    main_mod->add_literal(axes_lit);

    migraphx::op::reduce_mean mean_op{{2, 3}};
    main_mod->add_instruction(mean_op, input);
    auto reduced = main_mod->add_instruction(mean_op, input);
    main_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), reduced);

    auto parsed = optimize_tf("mean_test.pb", false);
    EXPECT(expected == parsed);
}

TEST_CASE(mean_test_nhwc)
{
    // NHWC Mean over axes {1, 2}: the {1, 3, 16, 16} input is transposed with
    // {0, 2, 3, 1}, reduced with reduce_mean, and squeezed on the reduced axes.
    //
    // No axes literal is expected. optimize_tf runs dead_code_elimination on
    // NHWC graphs, which presumably strips the unused axes constant, so the
    // reference program does not build one.
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto l1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
    migraphx::op::reduce_mean op{{1, 2}};
    auto l2 = mm->add_instruction(op, l1);
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2);
    auto prog = optimize_tf("mean_test_nhwc.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
505
506
507
508
TEST_CASE(mul_test)
{
    // Element-wise multiply of two identically shaped float inputs.
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 1, 1, 16}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto lhs       = main_mod->add_parameter("0", in_shape);
    auto rhs       = main_mod->add_parameter("1", in_shape);
    main_mod->add_instruction(migraphx::make_op("mul"), lhs, rhs);

    auto parsed = optimize_tf("mul_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
TEST_CASE(multi_output_test)
{
    // Requesting an output node that is presumably absent from the graph
    // ("relu6") must throw. Requesting "relu" and "tanh" returns both results
    // in that order.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto relu_out = main_mod->add_instruction(migraphx::make_op("relu"), input);
    auto tanh_out = main_mod->add_instruction(migraphx::make_op("tanh"), input);
    main_mod->add_return({relu_out, tanh_out});

    auto parse_bad_outputs = [&] {
        parse_tf("multi_output_test.pb", false, {}, {"relu", "relu6"});
    };
    EXPECT(test::throws(parse_bad_outputs));

    auto parsed = parse_tf("multi_output_test.pb", false, {}, {"relu", "tanh"});
    EXPECT(expected == parsed);
}

535
536
537
TEST_CASE(onehot_test)
{
    // OneHot is lowered to gathering rows (axis 0) of a 2x2 identity matrix,
    // indexed by the five indices. The depth/on/off constants remain as literals.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto indices = main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
    main_mod->add_literal(2);
    main_mod->add_literal(1.0f);
    main_mod->add_literal(0.0f);
    auto eye = main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});

    int row_axis = 0;
    main_mod->add_instruction(migraphx::make_op("gather", {{"axis", row_axis}}), eye, indices);

    auto parsed = optimize_tf("onehot_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
554
555
556
557
558
559
560
561
TEST_CASE(noop_test)
{
    // The graph (presumably a lone NoOp) must parse to an empty program.
    migraphx::program empty_program;
    auto parsed = optimize_tf("noop_test.pb", false);
    EXPECT(empty_program == parsed);
}

Khalique's avatar
Khalique committed
562
563
564
TEST_CASE(pack_test)
{
    // Pack of three {2} inputs along axis 1: each input is unsqueezed at axis 1
    // and the results are concatenated along that axis. The parameters are
    // created first, then the unsqueezes, in input order.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {2}};

    std::vector<migraphx::instruction_ref> inputs;
    for(const char* param_name : {"0", "1", "2"})
        inputs.push_back(main_mod->add_parameter(param_name, in_shape));

    int64_t pack_axis = 1;
    std::vector<migraphx::instruction_ref> expanded;
    for(auto in : inputs)
    {
        expanded.push_back(
            main_mod->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {pack_axis}}}), in));
    }
    main_mod->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(pack_axis)}}),
                              expanded);

    auto parsed = optimize_tf("pack_test.pb", false);
    EXPECT(expected == parsed);
}

588
589
590
TEST_CASE(pack_test_nhwc)
{
    // NHWC Pack: each {1, 2, 1, 1} input is transposed with {0, 2, 3, 1}
    // immediately after it is declared. The transposed values are unsqueezed
    // at axis 3 and concatenated along axis 3.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 1, 1}};

    std::vector<migraphx::instruction_ref> nhwc_inputs;
    for(const char* param_name : {"0", "1", "2"})
    {
        auto param = main_mod->add_parameter(param_name, in_shape);
        nhwc_inputs.push_back(main_mod->add_instruction(
            migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), param));
    }

    int64_t nchw_axis = 3;
    std::vector<migraphx::instruction_ref> expanded;
    for(auto in : nhwc_inputs)
    {
        expanded.push_back(
            main_mod->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {nchw_axis}}}), in));
    }
    main_mod->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(nchw_axis)}}),
                              expanded);

    auto parsed = optimize_tf("pack_test_nhwc.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
TEST_CASE(pad_test)
{
    // TF paddings [[1, 1], [2, 2]] (per-dim before/after) become the pad op's
    // flattened {1, 2, 1, 2} (all "before" values, then all "after" values).
    // The paddings tensor itself remains as an int32 {2, 2} literal.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    std::vector<int> tf_paddings{1, 1, 2, 2};
    std::vector<int> mgx_pads{1, 2, 1, 2};
    main_mod->add_literal(migraphx::shape{migraphx::shape::int32_type, {2, 2}}, tf_paddings);

    main_mod->add_instruction(migraphx::make_op("pad", {{"pads", mgx_pads}}), input);

    auto parsed = optimize_tf("pad_test.pb", false);
    EXPECT(expected == parsed);
}

634
635
636
TEST_CASE(pooling_test)
{
    // Average pooling then max pooling on the same input, both with a 2x2
    // window and stride 2.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});

    migraphx::op::pooling avg_pool{"average"};
    migraphx::op::pooling max_pool{"max"};
    for(auto* pool : {&avg_pool, &max_pool})
    {
        pool->stride  = {2, 2};
        pool->lengths = {2, 2};
    }

    main_mod->add_instruction(avg_pool, input);
    main_mod->add_instruction(max_pool, input);

    auto parsed = optimize_tf("pooling_test.pb", true);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
653
654
655
TEST_CASE(pow_test)
{
    // Element-wise pow of two identically shaped float inputs.
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 2, 3}};

    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto base      = main_mod->add_parameter("0", in_shape);
    auto exponent  = main_mod->add_parameter("1", in_shape);
    main_mod->add_instruction(migraphx::make_op("pow"), base, exponent);

    auto parsed = optimize_tf("pow_test.pb", false);
    EXPECT(expected == parsed);
}

666
667
668
TEST_CASE(relu_test)
{
    // Relu lowers to a single relu instruction on the input.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_instruction(migraphx::make_op("relu"), input);

    auto parsed = optimize_tf("relu_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
678
679
680
TEST_CASE(relu6_test)
{
    // Relu6 lowers to clip(x, 0, 6). Both bounds are scalar literals
    // multibroadcast to the input's shape.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    std::vector<size_t> in_lens{1, 3, 16, 16};

    auto input    = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, in_lens});
    auto lo       = main_mod->add_literal(0.0f);
    auto hi       = main_mod->add_literal(6.0f);
    auto lo_bcast = main_mod->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", in_lens}}), lo);
    auto hi_bcast = main_mod->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", in_lens}}), hi);
    main_mod->add_instruction(migraphx::make_op("clip"), input, lo_bcast, hi_bcast);

    auto parsed = optimize_tf("relu6_test.pb", false);
    EXPECT(expected == parsed);
}

697
698
699
TEST_CASE(reshape_test)
{
    // Reshape {16} -> {1, 1, 1, 16}.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});

    // In TF the target dimensions arrive as a second (int32) input, which shows
    // up as a literal in the parsed program.
    migraphx::shape dims_shape{migraphx::shape::int32_type, {4}};
    main_mod->add_literal(migraphx::literal{dims_shape, {1, 1, 1, 16}});
    main_mod->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 1, 16}}}), input);

    auto parsed = optimize_tf("reshape_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
712
713
714
TEST_CASE(rsqrt_test)
{
    // Rsqrt lowers to a single rsqrt instruction on the input.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    main_mod->add_instruction(migraphx::make_op("rsqrt"), input);

    auto parsed = optimize_tf("rsqrt_test.pb", false);
    EXPECT(expected == parsed);
}

724
725
726
TEST_CASE(shape_test)
{
    // Shape of a static {1, 3, 16, 16} input folds to an int32 literal of
    // those dimensions; the input parameter is still declared.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});

    migraphx::literal dims_lit{migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 16, 16}};
    main_mod->add_literal(dims_lit);

    auto parsed = optimize_tf("shape_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
737
738
739
TEST_CASE(slice_test)
{
    // The int32 operands {1, 0} and {2, -1} (presumably TF's begin/size) on a
    // {5, 10} input become a slice with starts {1, 0}, ends {3, 10} over axes
    // {0, 1}. The operands themselves remain as literals.
    migraphx::program expected;
    auto* main_mod       = expected.get_main_module();
    std::size_t num_axes = 2;

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}});
    migraphx::shape index_shape{migraphx::shape::int32_type, {num_axes}};
    main_mod->add_literal(migraphx::literal{index_shape, {1, 0}});
    main_mod->add_literal(migraphx::literal{index_shape, {2, -1}});

    migraphx::op::slice slice_op;
    slice_op.starts = {1, 0};
    slice_op.ends   = {3, 10};
    slice_op.axes.assign(num_axes, 0);
    std::iota(slice_op.axes.begin(), slice_op.axes.end(), 0);
    main_mod->add_instruction(slice_op, input);

    auto parsed = optimize_tf("slice_test.pb", false);
    EXPECT(expected == parsed);
}

759
760
761
// A rank-2 TF Softmax is expected to become migraphx softmax over axis 1.
TEST_CASE(softmax_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto logits = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
    main_mod->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), logits);

    auto parsed = optimize_tf("softmax_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
771
772
773
// TF Split of a {5, 30} input into three equal chunks along axis 1. The graph
// then concatenates neighbouring chunks, giving two returned outputs.
TEST_CASE(split_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    std::vector<int64_t> axes{0, 1};
    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});

    // Constant operands of the TF graph, in the same order as the parsed program.
    main_mod->add_literal(3); // num_splits
    main_mod->add_literal(1); // split axis
    main_mod->add_literal(1); // concat axis
    main_mod->add_literal(1); // concat axis

    // Chunks [0, 10), [10, 20), [20, 30) along axis 1, emitted in ascending order.
    const int chunk = 10;
    std::vector<migraphx::instruction_ref> chunks;
    for(int begin = 0; begin < 30; begin += chunk)
    {
        chunks.push_back(main_mod->add_instruction(
            migraphx::make_op("slice",
                              {{"axes", axes}, {"starts", {0, begin}}, {"ends", {5, begin + chunk}}}),
            input));
    }

    auto left_pair =
        main_mod->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), chunks[0], chunks[1]);
    auto right_pair =
        main_mod->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), chunks[1], chunks[2]);
    main_mod->add_return({left_pair, right_pair});

    auto parsed = parse_tf("split_test.pb", false);
    EXPECT(expected == parsed);
}

// With num_splits = 1, the expected program is just an identity on the input.
TEST_CASE(split_test_one_output)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    main_mod->add_literal(1); // num_splits
    main_mod->add_literal(1); // split axis

    auto passthrough = main_mod->add_instruction(migraphx::make_op("identity"), input);
    main_mod->add_return({passthrough});

    auto parsed = parse_tf("split_test_one_output.pb", false);
    EXPECT(expected == parsed);
}

// TF SplitV-style split of a {5, 30} input, with explicit sizes {4, 15, 11} along
// axis 1. Neighbouring pieces are then concatenated into two returned outputs.
TEST_CASE(split_test_vector_as_input)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    std::vector<int64_t> axes{0, 1};
    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});

    // split sizes
    main_mod->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {3}}, {4, 15, 11}});
    main_mod->add_literal(1); // split axis
    main_mod->add_literal(1); // concat axis
    main_mod->add_literal(1); // concat axis

    // Walk the split sizes to produce [0, 4), [4, 19), [19, 30) along axis 1.
    std::vector<migraphx::instruction_ref> pieces;
    int offset = 0;
    for(int width : {4, 15, 11})
    {
        pieces.push_back(main_mod->add_instruction(
            migraphx::make_op("slice",
                              {{"axes", axes}, {"starts", {0, offset}}, {"ends", {5, offset + width}}}),
            input));
        offset += width;
    }

    auto left_pair =
        main_mod->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), pieces[0], pieces[1]);
    auto right_pair =
        main_mod->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), pieces[1], pieces[2]);
    main_mod->add_return({left_pair, right_pair});

    auto parsed = parse_tf("split_test_vector_as_input.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
838
839
840
// Squared difference of two equally-shaped float inputs maps onto migraphx "sqdiff".
TEST_CASE(sqdiff_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape operand_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
    auto lhs = main_mod->add_parameter("0", operand_shape);
    auto rhs = main_mod->add_parameter("1", operand_shape);
    main_mod->add_instruction(migraphx::make_op("sqdiff"), lhs, rhs);

    auto parsed = optimize_tf("sqdiff_test.pb", false);
    EXPECT(expected == parsed);
}

851
852
853
// Squeeze removes the two unit dimensions (axes 0 and 3) of a {1, 2, 3, 1} input.
TEST_CASE(squeeze_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 3, 1}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 3}}}), input);

    auto parsed = optimize_tf("squeeze_test.pb", false);
    EXPECT(expected == parsed);
}
Khalique's avatar
Khalique committed
862

Khalique's avatar
Khalique committed
863
864
865
// The expected program maps TF StopGradient to a plain identity on the input.
TEST_CASE(stopgradient_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::make_op("identity"), input);

    auto parsed = optimize_tf("stopgradient_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
875
876
877
// StridedSlice parsed in NHWC mode. The {1, 10, 1, 1} input is permuted with
// {0, 2, 3, 1} (axis 1 moved last), sliced to the first 5 entries of that last axis,
// and axis 1 is squeezed away (shrink axis).
TEST_CASE(stridedslice_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
    auto l1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);

    // Build the slice via make_op, like the rest of this file, instead of populating
    // op::slice fields and filling axes with std::iota. <numeric> is not directly
    // included by this file.
    auto l2 = mm->add_instruction(
        migraphx::make_op("slice",
                          {{"axes", {0, 1, 2, 3}}, {"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}}),
        l1);

    auto shrink_axis = 1;
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);

    auto prog = optimize_tf("stridedslice_test.pb", true);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
896
897
898
// StridedSlice with begin/end masks, parsed in NHWC mode.
// - The ends constant is all zeros, yet the expected slice ends are {1, 3, 3, 10}.
//   Presumably the masks stored in the .pb select the full extent; the mask values
//   are not visible here.
// - The result is permuted back with {0, 3, 1, 2} and returned.
TEST_CASE(stridedslice_masks_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});

    // add literals for starts, ends, and strides in tf (NHWC format)
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 1, 1, 0});
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 0, 0, 0});
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{1, 1, 1, 1});

    auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);

    // Build the slice via make_op, like the rest of this file, instead of populating
    // op::slice fields and filling axes with std::iota. <numeric> is not directly
    // included by this file.
    auto l2 = mm->add_instruction(
        migraphx::make_op("slice",
                          {{"axes", {0, 1, 2, 3}}, {"starts", {0, 1, 1, 0}}, {"ends", {1, 3, 3, 10}}}),
        l1);

    auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2);
    mm->add_return({l3});

    auto prog = parse_tf("stridedslice_masks_test.pb", true);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
925
926
927
// Element-wise subtraction of two equally-shaped inputs; the difference is returned.
TEST_CASE(sub_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape operand_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
    auto minuend    = main_mod->add_parameter("0", operand_shape);
    auto subtrahend = main_mod->add_parameter("1", operand_shape);
    auto difference = main_mod->add_instruction(migraphx::make_op("sub"), minuend, subtrahend);
    main_mod->add_return({difference});

    auto parsed = parse_tf("sub_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
939
940
941
// Element-wise tanh of the input, returned as the sole output.
TEST_CASE(tanh_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input     = main_mod->add_parameter("0", input_shape);
    auto activated = main_mod->add_instruction(migraphx::make_op("tanh"), input);
    main_mod->add_return({activated});

    auto parsed = parse_tf("tanh_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
952
953
954
// TF Transpose with a constant permutation {0, 2, 3, 1}. The permutation stays in the
// graph as an int32 literal and is also baked into the migraphx transpose op.
TEST_CASE(transpose_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = main_mod->add_parameter("0", input_shape);

    const migraphx::shape perm_shape{migraphx::shape::int32_type, {4}};
    main_mod->add_literal(migraphx::literal{perm_shape, {0, 2, 3, 1}});
    main_mod->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), input);

    auto parsed = optimize_tf("transpose_test.pb", false);
    EXPECT(expected == parsed);
}

966
967
968
// Expected parse of variable_batch_test.pb: a {1, 3, 16, 16} input followed by an
// identity. The leading 1 presumably comes from the batch size passed in
// optimize_tf's tf_options (its second field is 1) — confirm against tf_options.
TEST_CASE(variable_batch_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::make_op("identity"), input);

    auto parsed = optimize_tf("variable_batch_test.pb", false);
    EXPECT(expected == parsed);
}

978
int main(int argc, const char* argv[])
{
    // Delegate to the harness from test.hpp, forwarding the command-line arguments.
    test::run(argc, argv);
}