tf_test.cpp 32.6 KB
Newer Older
1
2
#include <algorithm>
#include <iostream>
#include <iterator>
#include <numeric>
#include <unordered_map>
#include <vector>

#include <migraphx/literal.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/slice.hpp>

#include <migraphx/serialize.hpp>

#include "test.hpp"

Shucai Xiao's avatar
Shucai Xiao committed
23
24
25
/// Parse a TensorFlow protobuf file into a MIGraphX program without running any passes.
/// @param name              path of the .pb file to load
/// @param is_nhwc           forwarded as the layout flag of tf_options
/// @param dim_params        optional per-input dimension overrides, keyed by input name
/// @param output_node_names optional explicit output nodes; empty leaves the parser default
/// @return the parsed program
/// NOTE(review): the second tf_options field is fixed at 1 (presumably the batch size).
migraphx::program
parse_tf(const std::string& name,
         bool is_nhwc,
         const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {},
         const std::vector<std::string>& output_node_names                           = {})
{
    const migraphx::tf_options options{is_nhwc, 1, dim_params, output_node_names};
    return migraphx::parse_tf(name, options);
}

Paul's avatar
Paul committed
33
34
/// Parse a TensorFlow protobuf file and normalize the result so it can be compared
/// against a hand-built reference program.
/// @param name    path of the .pb file to load
/// @param is_nhwc layout flag; when true, simplify_reshapes, dead_code_elimination and
///                eliminate_identity are run on the main module
/// @return the parsed program with a trailing "@return" instruction (if any) removed,
///         so reference programs do not need to add one
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
    auto prog = migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1});
    auto* mm  = prog.get_main_module();
    if(is_nhwc)
    {
        migraphx::run_passes(*mm,
                             {migraphx::simplify_reshapes{},
                              migraphx::dead_code_elimination{},
                              migraphx::eliminate_identity{}});
    }

    // Remove the trailing return instruction. The emptiness check must come first:
    // std::prev(end()) on an empty range is undefined behavior, and comparing its
    // result against end() can never detect that case.
    if(mm->begin() != mm->end())
    {
        auto last_ins = std::prev(mm->end());
        if(last_ins->name() == "@return")
        {
            mm->remove_instruction(last_ins);
        }
    }
    return prog;
}

55
56
57
TEST_CASE(add_test)
{
    // Reference: elementwise add of two identically shaped float inputs.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
    auto lhs = mm->add_parameter("0", in_shape);
    auto rhs = mm->add_parameter("1", in_shape);
    mm->add_instruction(migraphx::make_op("add"), lhs, rhs);

    auto prog = optimize_tf("add_test.pb", false);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
68
69
70
TEST_CASE(addv2_test)
{
    // Reference: the AddV2 graph is expected to produce the same plain add as add_test.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
    auto lhs = mm->add_parameter("0", in_shape);
    auto rhs = mm->add_parameter("1", in_shape);
    mm->add_instruction(migraphx::make_op("add"), lhs, rhs);

    auto prog = optimize_tf("addv2_test.pb", false);
    EXPECT(p == prog);
}

80
81
TEST_CASE(add_bcast_test)
{
    // Reference: the {2, 1} operand is multibroadcast to {2, 3} before the add.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape full_shape{migraphx::shape::float_type, {2, 3}};
    const migraphx::shape narrow_shape{migraphx::shape::float_type, {2, 1}};
    auto full    = mm->add_parameter("0", full_shape);
    auto narrow  = mm->add_parameter("1", narrow_shape);
    auto widened = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", full_shape.lens()}}), narrow);
    mm->add_instruction(migraphx::make_op("add"), full, widened);

    auto prog = optimize_tf("add_bcast_test.pb", false);
    EXPECT(p == prog);
}

97
98
99
TEST_CASE(argmax_test)
{
    // Reference: the int32 axis constant stays as a literal; argmax over axis 2 is
    // followed by a squeeze of that axis, and the result is returned explicitly.
    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}});
    mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
    auto reduced  = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), input);
    auto squeezed = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), reduced);
    mm->add_return({squeezed});

    // parse_tf (unlike optimize_tf) keeps the @return; input dims are pinned explicitly.
    auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}});
    EXPECT(p == prog);
}

TEST_CASE(argmin_test)
{
    // Reference: like argmax_test but with argmin and the parser's default input dims.
    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
    mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
    auto reduced  = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), input);
    auto squeezed = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), reduced);
    mm->add_return({squeezed});

    auto prog = parse_tf("argmin_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
127
128
129
TEST_CASE(assert_less_equal_test)
{
    // Reference: an int32 literal, the add, then two identity instructions
    // (the second consuming the first identity and the literal).
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {2, 3}};
    auto a        = mm->add_parameter("0", in_shape);
    auto b        = mm->add_parameter("1", in_shape);
    auto data_lit = mm->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}});
    mm->add_instruction(migraphx::make_op("add"), a, b);
    auto first_id = mm->add_instruction(migraphx::make_op("identity"), a, b);
    mm->add_instruction(migraphx::make_op("identity"), first_id, data_lit);

    auto prog = optimize_tf("assert_less_equal_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
145
146
147
148
TEST_CASE(batchmatmul_test)
{
    // Reference: both operands have their two innermost axes swapped before the dot.
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto a   = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}});
    auto b   = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});

    const auto swap_inner = migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}});
    auto a_t              = mm->add_instruction(swap_inner, a);
    auto b_t              = mm->add_instruction(swap_inner, b);
    mm->add_instruction(migraphx::make_op("dot"), a_t, b_t);

    auto prog = optimize_tf("batchmatmul_test.pb", false);
    EXPECT(p == prog);
}

164
165
TEST_CASE(batchnorm_test)
{
    // Reference: batch_norm_inference in spatial mode whose second argument is a
    // constant all-ones literal and whose remaining three arguments are parameters 2-4.
    const float epsilon  = 1.001e-5f;
    const float momentum = 0.9f;

    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::op::batch_norm_inference bn{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
    const migraphx::shape stat_shape{migraphx::shape::float_type, {32}};
    auto x = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
    const std::vector<float> ones(32, 1.0f);

    auto p2       = mm->add_parameter("2", stat_shape);
    auto p3       = mm->add_parameter("3", stat_shape);
    auto p4       = mm->add_parameter("4", stat_shape);
    auto ones_lit = mm->add_literal(migraphx::literal{stat_shape, ones});
    mm->add_instruction(bn, x, ones_lit, p2, p3, p4);

    auto prog = optimize_tf("batchnorm_test.pb", true);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
189
190
191
192
193
194
TEST_CASE(batchnormv3_test)
{
    // Reference: same structure as batchnorm_test, with epsilon 1e-5 for the V3 graph.
    const float epsilon  = 1.0e-5f;
    const float momentum = 0.9f;

    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::op::batch_norm_inference bn{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
    const migraphx::shape stat_shape{migraphx::shape::float_type, {32}};
    auto x = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
    const std::vector<float> ones(32, 1.0f);

    auto p2       = mm->add_parameter("2", stat_shape);
    auto p3       = mm->add_parameter("3", stat_shape);
    auto p4       = mm->add_parameter("4", stat_shape);
    auto ones_lit = mm->add_literal(migraphx::literal{stat_shape, ones});
    mm->add_instruction(bn, x, ones_lit, p2, p3, p4);

    auto prog = optimize_tf("batchnormv3_test.pb", true);
    EXPECT(p == prog);
}

213
214
215
TEST_CASE(biasadd_test)
{
    // Reference: the 1-D bias is broadcast along axis 1 to the input's lens, then added.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 500, 1, 1}};
    const uint64_t bias_axis = 1;
    auto input               = mm->add_parameter("0", in_shape);
    auto bias = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
    const auto out_lens = input->get_shape().lens();
    auto bias_b         = mm->add_instruction(
        migraphx::make_op("broadcast", {{"axis", bias_axis}, {"out_lens", out_lens}}), bias);
    mm->add_instruction(migraphx::make_op("add"), input, bias_b);

    auto prog = optimize_tf("biasadd_test.pb", true);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
230
231
232
233
234
235
236
237
238
239
TEST_CASE(biasadd_scalar_test)
{
    // Reference: a single-element (stride 0) bias literal is broadcast along axis 1.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 1}};
    const uint64_t bias_axis = 1;
    auto input               = mm->add_parameter("0", in_shape);
    const migraphx::shape bias_shape{migraphx::shape::float_type, {1}, {0}};
    auto bias           = mm->add_literal(migraphx::literal{bias_shape, {1.0}});
    const auto out_lens = input->get_shape().lens();
    auto bias_b         = mm->add_instruction(
        migraphx::make_op("broadcast", {{"axis", bias_axis}, {"out_lens", out_lens}}), bias);
    mm->add_instruction(migraphx::make_op("add"), input, bias_b);

    auto prog = optimize_tf("biasadd_scalar_test.pb", true);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
248
249
250
TEST_CASE(cast_test)
{
    // Reference: a single convert of the float input to int32.
    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    const auto to_int32 = migraphx::make_op(
        "convert", {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}});
    mm->add_instruction(to_int32, input);

    auto prog = optimize_tf("cast_test.pb", false);
    EXPECT(p == prog);
}

263
264
265
TEST_CASE(concat_test)
{
    // Reference: the concat axis arrives as an extra int32 input in TF, so it appears as a
    // literal. It is built from a vector so its stride matches what the parser creates.
    migraphx::program p;
    auto* mm = p.get_main_module();

    auto first  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}});
    auto second = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});

    const int concat_axis = 1;
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{concat_axis});
    mm->add_instruction(migraphx::make_op("concat", {{"axis", concat_axis}}), first, second);

    auto prog = optimize_tf("concat_test.pb", false);
    EXPECT(p == prog);
}

TEST_CASE(const_test)
{
    // Reference: the graph consists of a single scalar float constant 1.0.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const std::vector<float> value{1.0f};
    mm->add_literal(migraphx::shape{migraphx::shape::float_type}, value);

    auto prog = optimize_tf("constant_test.pb", false);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
294
/// Build the reference program shared by the conv_* tests. It convolves a 1x3x16x16
/// float input with a 3x3x3x32 all-ones weight literal (same padding, stride and
/// dilation 1); the weights are transposed with permutation {3, 2, 0, 1} first.
/// @return the reference program; its last instruction is the convolution
migraphx::program create_conv()
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    const std::vector<float> weights(3 * 3 * 3 * 32, 1.0f);
    auto weight_lit =
        mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weights);

    migraphx::op::convolution conv;
    conv.padding_mode = migraphx::op::padding_mode_t::same;
    conv.padding      = {1, 1, 1, 1};
    conv.stride       = {1, 1};
    conv.dilation     = {1, 1};

    auto weights_t = mm->add_instruction(
        migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), weight_lit);
    mm->add_instruction(conv, input, weights_t);
    return p;
}

TEST_CASE(conv_test)
{
    // NHWC conv graph must match the shared convolution reference.
    const migraphx::program expected = create_conv();
    const auto actual                = optimize_tf("conv_test.pb", true);
    EXPECT(expected == actual);
}

kahmed10's avatar
kahmed10 committed
325
326
327
328
329
330
331
332
333
334
335
TEST_CASE(conv_add_test)
{
    // Reference: the shared convolution with its output added to itself.
    migraphx::program p = create_conv();
    auto* mm            = p.get_main_module();
    auto conv_out       = std::prev(mm->end());
    mm->add_instruction(migraphx::make_op("add"), conv_out, conv_out);

    auto prog = optimize_tf("conv_add_test.pb", true);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
336
337
338
339
TEST_CASE(conv_nchw_test)
{
    // An NCHW-parsed conv graph must match the same shared convolution reference.
    const migraphx::program expected = create_conv();
    const auto actual                = optimize_tf("conv_nchw_test.pb", false);
    EXPECT(expected == actual);
}

kahmed10's avatar
kahmed10 committed
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
TEST_CASE(conv_relu_test)
{
    // Reference: the shared convolution followed by relu.
    migraphx::program p = create_conv();
    auto* mm            = p.get_main_module();
    auto conv_out       = std::prev(mm->end());
    mm->add_instruction(migraphx::make_op("relu"), conv_out);

    auto prog = optimize_tf("conv_relu_test.pb", true);
    EXPECT(p == prog);
}

TEST_CASE(conv_relu6_test)
{
    // Reference: the shared convolution clipped to [0, 6]; both bounds are scalar
    // literals multibroadcast to the conv output lens {1, 32, 16, 16}.
    migraphx::program p = create_conv();
    auto* mm            = p.get_main_module();
    const std::vector<size_t> out_lens{1, 32, 16, 16};
    auto conv_out = std::prev(mm->end());
    auto lo       = mm->add_literal(0.0f);
    auto hi       = mm->add_literal(6.0f);

    const auto bcast = migraphx::make_op("multibroadcast", {{"out_lens", out_lens}});
    lo               = mm->add_instruction(bcast, lo);
    hi               = mm->add_instruction(bcast, hi);
    mm->add_instruction(migraphx::make_op("clip"), conv_out, lo, hi);

    auto prog = optimize_tf("conv_relu6_test.pb", true);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
373
374
375
376
TEST_CASE(depthwiseconv_test)
{
    // Reference: a grouped (group = 3) convolution. The 3x3x3x1 all-ones weight literal
    // is transposed, made contiguous and reshaped to {3, 1, 3, 3} before use.
    migraphx::program p;
    auto* mm = p.get_main_module();

    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    const std::vector<float> weights(3 * 3 * 3 * 1, 1.0f);
    auto weight_lit =
        mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weights);

    migraphx::op::convolution conv;
    conv.padding_mode = migraphx::op::padding_mode_t::same;
    conv.padding      = {1, 1};
    conv.stride       = {1, 1};
    conv.dilation     = {1, 1};
    conv.group        = 3;

    auto transposed = mm->add_instruction(
        migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), weight_lit);
    auto packed = mm->add_instruction(migraphx::make_op("contiguous"), transposed);
    auto reshaped =
        mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), packed);
    mm->add_instruction(conv, input, reshaped);

    auto prog = optimize_tf("depthwise_conv_test.pb", true);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
401
402
403
TEST_CASE(expanddims_test)
{
    // Reference: ExpandDims at position 0 keeps its int literal and becomes a reshape
    // to {1, 2, 3, 4}.
    migraphx::program p;
    auto* mm = p.get_main_module();

    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    mm->add_literal(0);
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4}}}), input);

    auto prog = optimize_tf("expanddims_test.pb", false);
    EXPECT(p == prog);
}

TEST_CASE(expanddims_test_neg_dims)
{
    // Checks that a negative ExpandDims position (-1) is parsed correctly: the new unit
    // dimension is appended, giving a reshape to {2, 3, 4, 1}.
    migraphx::program p;
    auto* mm = p.get_main_module();

    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    mm->add_literal(-1);
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), input);

    auto prog = optimize_tf("expanddims_neg_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
430
431
432
433
TEST_CASE(gather_test)
{
    // Reference: gather along axis 1 using a constant int32 index literal {1, 1}; the
    // axis constant itself also remains as a literal.
    migraphx::program p;
    auto* mm = p.get_main_module();

    auto data    = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    auto indices = mm->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}});
    mm->add_literal(1);

    const int gather_axis = 1;
    mm->add_instruction(migraphx::make_op("gather", {{"axis", gather_axis}}), data, indices);

    auto prog = optimize_tf("gather_test.pb", false);
    EXPECT(p == prog);
}

448
449
450
TEST_CASE(identity_test)
{
    // Reference: a single identity applied to the input.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = mm->add_parameter("0", in_shape);
    mm->add_instruction(migraphx::make_op("identity"), input);

    auto prog = optimize_tf("identity_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
460
461
462
TEST_CASE(matmul_test)
{
    // Reference: both 2-D operands are transposed before the dot.
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto a   = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
    auto b   = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});

    const auto flip = migraphx::make_op("transpose", {{"permutation", {1, 0}}});
    auto a_t        = mm->add_instruction(flip, a);
    auto b_t        = mm->add_instruction(flip, b);
    mm->add_instruction(migraphx::make_op("dot"), a_t, b_t);

    auto prog = optimize_tf("matmul_test.pb", false);
    EXPECT(p == prog);
}

479
480
481
TEST_CASE(mean_test)
{
    // Reference: two axis literals and two reduce_means over axes {2, 3}. Only the second
    // reduction is squeezed, mirroring the expected parse of mean_test.pb.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::literal axes_lit{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}};

    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    mm->add_literal(axes_lit);
    mm->add_literal(axes_lit);

    const migraphx::op::reduce_mean mean_hw{{2, 3}};
    mm->add_instruction(mean_hw, input);
    auto kept = mm->add_instruction(mean_hw, input);
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), kept);

    auto prog = optimize_tf("mean_test.pb", false);
    EXPECT(p == prog);
}

TEST_CASE(mean_test_nhwc)
{
    // Reference for the NHWC mean graph: the NCHW input is transposed with {0, 2, 3, 1}.
    // reduce_mean then runs over axes {1, 2} of the transposed tensor, and those axes
    // are squeezed. No axis literal is expected here: the previous version built one
    // (`l`) but never added it to the module, so that dead local has been removed.
    migraphx::program p;
    auto* mm = p.get_main_module();

    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto nhwc =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), input);
    migraphx::op::reduce_mean op{{1, 2}};
    auto reduced = mm->add_instruction(op, nhwc);
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), reduced);

    auto prog = optimize_tf("mean_test_nhwc.pb", true);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
514
515
516
517
TEST_CASE(mul_test)
{
    // Reference: elementwise mul of two identically shaped float inputs.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 1, 1, 16}};
    auto lhs = mm->add_parameter("0", in_shape);
    auto rhs = mm->add_parameter("1", in_shape);
    mm->add_instruction(migraphx::make_op("mul"), lhs, rhs);

    auto prog = optimize_tf("mul_test.pb", false);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
TEST_CASE(multi_output_test)
{
    // Reference: relu and tanh of the same input, both returned. Requesting an output
    // node ("relu6") that the graph does not contain must make parsing throw.
    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto relu_out = mm->add_instruction(migraphx::make_op("relu"), input);
    auto tanh_out = mm->add_instruction(migraphx::make_op("tanh"), input);
    mm->add_return({relu_out, tanh_out});

    auto parse_with_missing_output = [&] {
        parse_tf("multi_output_test.pb", false, {}, {"relu", "relu6"});
    };
    EXPECT(test::throws(parse_with_missing_output));

    auto prog = parse_tf("multi_output_test.pb", false, {}, {"relu", "tanh"});
    EXPECT(p == prog);
}

544
545
546
TEST_CASE(onehot_test)
{
    // Reference: one-hot is expressed as a gather (axis 0) from a constant 2x2 identity
    // matrix, indexed by the int32 literal of indices. The scalar literals presumably
    // hold the TF depth / on / off values (TODO confirm against the .pb).
    migraphx::program p;
    auto* mm     = p.get_main_module();
    auto indices = mm->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
    mm->add_literal(2);
    mm->add_literal(1.0f);
    mm->add_literal(0.0f);
    auto identity2x2 = mm->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});

    const int gather_axis = 0;
    mm->add_instruction(migraphx::make_op("gather", {{"axis", gather_axis}}), identity2x2, indices);

    auto prog = optimize_tf("onehot_test.pb", false);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
563
564
565
566
567
568
569
570
TEST_CASE(noop_test)
{
    // A graph containing only a NoOp must parse to an empty program.
    const migraphx::program expected;
    const auto actual = optimize_tf("noop_test.pb", false);
    EXPECT(expected == actual);
}

Khalique's avatar
Khalique committed
571
572
573
TEST_CASE(pack_test)
{
    // Reference: Pack becomes an unsqueeze of each input at axis 1 followed by a concat
    // along that same axis.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {2}};
    auto in0 = mm->add_parameter("0", in_shape);
    auto in1 = mm->add_parameter("1", in_shape);
    auto in2 = mm->add_parameter("2", in_shape);

    const int64_t pack_axis = 1;
    std::vector<migraphx::instruction_ref> expanded;
    for(auto ins : {in0, in1, in2})
    {
        expanded.push_back(
            mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {pack_axis}}}), ins));
    }
    mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(pack_axis)}}),
                        expanded);

    auto prog = optimize_tf("pack_test.pb", false);
    EXPECT(p == prog);
}

597
598
599
TEST_CASE(pack_test_nhwc)
{
    // Reference: each NCHW input is transposed with {0, 2, 3, 1}, i.e. reordered to NHWC,
    // right after it is declared. Each is then unsqueezed at axis 3 and all are
    // concatenated along axis 3.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 1, 1}};

    std::vector<migraphx::instruction_ref> nhwc_inputs;
    for(const char* name : {"0", "1", "2"})
    {
        auto param = mm->add_parameter(name, in_shape);
        nhwc_inputs.push_back(mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), param));
    }

    const int64_t nchw_axis = 3;
    std::vector<migraphx::instruction_ref> expanded;
    for(auto ins : nhwc_inputs)
    {
        expanded.push_back(
            mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {nchw_axis}}}), ins));
    }
    mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(nchw_axis)}}),
                        expanded);

    auto prog = optimize_tf("pack_test_nhwc.pb", true);
    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
TEST_CASE(pad_test)
{
    // Reference: the TF paddings literal {{1, 1}, {2, 2}} (per-axis begin/end pairs) is
    // kept as an int32 literal. The pad op receives the same amounts reordered as
    // {1, 2, 1, 2}: all begins first, then all ends.
    migraphx::program p;
    auto* mm = p.get_main_module();

    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    const std::vector<int> tf_paddings{1, 1, 2, 2};
    const std::vector<int> mgx_pads{1, 2, 1, 2};
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {2, 2}}, tf_paddings);
    mm->add_instruction(migraphx::make_op("pad", {{"pads", mgx_pads}}), input);

    auto prog = optimize_tf("pad_test.pb", false);
    EXPECT(p == prog);
}

646
647
648
TEST_CASE(pooling_test)
{
    // Reference: an average pool and then a max pool over the same input, both with a
    // 2x2 window and stride 2.
    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});

    migraphx::op::pooling avg_pool{"average"};
    avg_pool.stride  = {2, 2};
    avg_pool.lengths = {2, 2};

    migraphx::op::pooling max_pool{"max"};
    max_pool.stride  = {2, 2};
    max_pool.lengths = {2, 2};

    mm->add_instruction(avg_pool, input);
    mm->add_instruction(max_pool, input);

    auto prog = optimize_tf("pooling_test.pb", true);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
665
666
667
TEST_CASE(pow_test)
{
    // Reference: elementwise pow of two identically shaped float inputs.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
    auto base     = mm->add_parameter("0", in_shape);
    auto exponent = mm->add_parameter("1", in_shape);
    mm->add_instruction(migraphx::make_op("pow"), base, exponent);

    auto prog = optimize_tf("pow_test.pb", false);
    EXPECT(p == prog);
}

678
679
680
TEST_CASE(relu_test)
{
    // Reference: a single relu applied to the input.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = mm->add_parameter("0", in_shape);
    mm->add_instruction(migraphx::make_op("relu"), input);

    auto prog = optimize_tf("relu_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
690
691
692
TEST_CASE(relu6_test)
{
    // Reference: relu6 becomes clip(x, 0, 6); both bounds are scalar literals
    // multibroadcast to the input lens.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const std::vector<size_t> lens{1, 3, 16, 16};
    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, lens});
    auto lo    = mm->add_literal(0.0f);
    auto hi    = mm->add_literal(6.0f);

    const auto bcast = migraphx::make_op("multibroadcast", {{"out_lens", lens}});
    lo               = mm->add_instruction(bcast, lo);
    hi               = mm->add_instruction(bcast, hi);
    mm->add_instruction(migraphx::make_op("clip"), input, lo, hi);

    auto prog = optimize_tf("relu6_test.pb", false);
    EXPECT(p == prog);
}

709
710
711
TEST_CASE(reshape_test)
{
    // Reference: TF passes the target dimensions as a second (literal) input; that
    // literal is kept, and the reshape targets {1, 1, 1, 16}.
    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});

    const migraphx::shape dims_shape{migraphx::shape::int32_type, {4}};
    mm->add_literal(migraphx::literal{dims_shape, {1, 1, 1, 16}});
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 1, 16}}}), input);

    auto prog = optimize_tf("reshape_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
724
725
726
TEST_CASE(rsqrt_test)
{
    // Reference: a single rsqrt applied to the input.
    migraphx::program p;
    auto* mm = p.get_main_module();
    const migraphx::shape in_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = mm->add_parameter("0", in_shape);
    mm->add_instruction(migraphx::make_op("rsqrt"), input);

    auto prog = optimize_tf("rsqrt_test.pb", false);
    EXPECT(p == prog);
}

736
737
738
TEST_CASE(shape_test)
{
    // Reference: Shape folds to an int32 literal holding the input dims.
    migraphx::program p;
    auto* mm = p.get_main_module();
    mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    const migraphx::shape dims_shape{migraphx::shape::int32_type, {4}};
    mm->add_literal(migraphx::literal{dims_shape, {1, 3, 16, 16}});

    auto prog = optimize_tf("shape_test.pb", false);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
749
750
751
TEST_CASE(slice_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}});

    // The slice covers both axes of the 2-D input. The axes are spelled out instead of
    // being generated with std::iota, which lives in <numeric> (not included by this file).
    const std::vector<int64_t> axes{0, 1};

    // The tf graph passes the slice bounds as int32 literals with one entry per axis.
    // Presumably these are the tf Slice begin/size operands: begin {1, 0} plus size {2, -1}
    // (-1 meaning "to the end") yields the ends {3, 10} expected below.
    migraphx::shape s0{migraphx::shape::int32_type, {axes.size()}};
    mm->add_literal(migraphx::literal{s0, {1, 0}});
    mm->add_literal(migraphx::literal{s0, {2, -1}});

    migraphx::op::slice op;
    op.starts = {1, 0};
    op.ends   = {3, 10};
    op.axes   = axes;
    mm->add_instruction(op, l0);

    auto prog = optimize_tf("slice_test.pb", false);

    EXPECT(p == prog);
}

771
772
773
TEST_CASE(softmax_test)
{
    // Expected program: softmax along axis 1 of a 1x3 input.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape logits_shape{migraphx::shape::float_type, {1, 3}};
    auto logits = main_mod->add_parameter("0", logits_shape);
    main_mod->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), logits);

    auto parsed = optimize_tf("softmax_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
783
784
785
TEST_CASE(split_test)
{
    // Expected program: a 5x30 input split into three 5x10 pieces along axis 1,
    // then neighbouring pieces concatenated pairwise; both concats are returned.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const std::vector<int64_t> slice_axes{0, 1};
    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});

    // Scalar constant operands of the tf graph, kept as literals by the parser.
    main_mod->add_literal(3); // num_splits
    main_mod->add_literal(1); // split axis
    main_mod->add_literal(1); // concat axis
    main_mod->add_literal(1); // concat axis

    // Piece boundaries along axis 1: [0, 10), [10, 20), [20, 30).
    auto left = main_mod->add_instruction(
        migraphx::make_op("slice",
                          {{"axes", slice_axes}, {"starts", {0, 0}}, {"ends", {5, 10}}}),
        input);
    auto middle = main_mod->add_instruction(
        migraphx::make_op("slice",
                          {{"axes", slice_axes}, {"starts", {0, 10}}, {"ends", {5, 20}}}),
        input);
    auto right = main_mod->add_instruction(
        migraphx::make_op("slice",
                          {{"axes", slice_axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}),
        input);

    auto left_middle =
        main_mod->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), left, middle);
    auto middle_right =
        main_mod->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), middle, right);
    main_mod->add_return({left_middle, middle_right});

    auto parsed = parse_tf("split_test.pb", false);
    EXPECT(expected == parsed);
}

TEST_CASE(split_test_one_output)
{
    // Expected program: a split into a single piece becomes an identity, and that
    // result is returned. This test uses parse_tf, so no optimization passes run.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {5, 30}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_literal(1); // num_splits
    main_mod->add_literal(1); // split axis

    auto whole = main_mod->add_instruction(migraphx::make_op("identity"), input);
    main_mod->add_return({whole});

    auto parsed = parse_tf("split_test_one_output.pb", false);
    EXPECT(expected == parsed);
}

TEST_CASE(split_test_vector_as_input)
{
    // Expected program: a 5x30 input split along axis 1 into pieces of width 4, 15
    // and 11, then neighbouring pieces concatenated pairwise; both concats are returned.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const std::vector<int64_t> slice_axes{0, 1};
    auto input = main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});

    // The split sizes arrive as an int32 vector operand instead of a num_splits scalar.
    const migraphx::shape sizes_shape{migraphx::shape::int32_type, {3}};
    main_mod->add_literal(migraphx::literal{sizes_shape, {4, 15, 11}});
    main_mod->add_literal(1); // split axis
    main_mod->add_literal(1); // concat axis
    main_mod->add_literal(1); // concat axis

    // Piece boundaries along axis 1: [0, 4), [4, 19), [19, 30).
    auto left = main_mod->add_instruction(
        migraphx::make_op("slice",
                          {{"axes", slice_axes}, {"starts", {0, 0}}, {"ends", {5, 4}}}),
        input);
    auto middle = main_mod->add_instruction(
        migraphx::make_op("slice",
                          {{"axes", slice_axes}, {"starts", {0, 4}}, {"ends", {5, 19}}}),
        input);
    auto right = main_mod->add_instruction(
        migraphx::make_op("slice",
                          {{"axes", slice_axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}),
        input);

    auto left_middle =
        main_mod->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), left, middle);
    auto middle_right =
        main_mod->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), middle, right);
    main_mod->add_return({left_middle, middle_right});

    auto parsed = parse_tf("split_test_vector_as_input.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
850
851
852
TEST_CASE(sqdiff_test)
{
    // Expected program: sqdiff of two identically shaped 1x2x2x3 inputs.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape operand_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
    auto lhs = main_mod->add_parameter("0", operand_shape);
    auto rhs = main_mod->add_parameter("1", operand_shape);
    main_mod->add_instruction(migraphx::make_op("sqdiff"), lhs, rhs);

    auto parsed = optimize_tf("sqdiff_test.pb", false);
    EXPECT(expected == parsed);
}

863
864
865
TEST_CASE(squeeze_test)
{
    // Expected program: the two size-1 axes (0 and 3) of a 1x2x3x1 input are squeezed.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 3, 1}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 3}}}), input);

    auto parsed = optimize_tf("squeeze_test.pb", false);
    EXPECT(expected == parsed);
}
Khalique's avatar
Khalique committed
874

Khalique's avatar
Khalique committed
875
876
877
TEST_CASE(stopgradient_test)
{
    // Expected program: tf StopGradient is parsed as an identity on the input.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::make_op("identity"), input);

    auto parsed = optimize_tf("stopgradient_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
887
888
889
TEST_CASE(stridedslice_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});

    // The model is parsed with is_nhwc = true, so the parameter is first transposed
    // with permutation {0, 2, 3, 1} (to NHWC layout) before slicing.
    auto l1 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);

    // Slice all four axes, keeping the first 5 of the 10 entries on the last axis.
    // The axes are listed explicitly instead of filled with std::iota, which lives in
    // <numeric> (not included by this file).
    migraphx::op::slice op;
    op.starts = {0, 0, 0, 0};
    op.ends   = {1, 1, 1, 5};
    op.axes   = {0, 1, 2, 3};
    auto l2   = mm->add_instruction(op, l1);

    // Squeeze away axis 1 (presumably from the tf shrink_axis_mask).
    auto shrink_axis = 1;
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);

    auto prog = optimize_tf("stridedslice_test.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
909
910
911
TEST_CASE(stridedslice_masks_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});

    // Slice all four axes. The axes are listed explicitly instead of filled with
    // std::iota, which lives in <numeric> (not included by this file).
    migraphx::op::slice op;
    op.starts = {0, 1, 1, 0};
    op.ends   = {1, 3, 3, 10};
    op.axes   = {0, 1, 2, 3};

    // add literals for starts, ends, and strides in tf (NHWC format)
    // NOTE(review): the ends literal is all zeros, yet op.ends is {1, 3, 3, 10}. Presumably
    // the tf begin/end masks select the full extent on those axes; confirm against the .pb.
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 1, 1, 0});
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 0, 0, 0});
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{1, 1, 1, 1});

    // is_nhwc = true: transpose in with {0, 2, 3, 1}, slice, then transpose the result
    // back with {0, 3, 1, 2} before returning it.
    auto l1 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
    auto l2 = mm->add_instruction(op, l1);
    auto l3 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2);
    mm->add_return({l3});

    auto prog = parse_tf("stridedslice_masks_test.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
940
941
942
TEST_CASE(sub_test)
{
    // Expected program: element-wise subtraction of two 1x2x2x3 inputs, returned directly
    // (parse_tf, so no optimization passes run).
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape operand_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
    auto minuend    = main_mod->add_parameter("0", operand_shape);
    auto subtrahend = main_mod->add_parameter("1", operand_shape);
    auto difference = main_mod->add_instruction(migraphx::make_op("sub"), minuend, subtrahend);
    main_mod->add_return({difference});

    auto parsed = parse_tf("sub_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
954
955
956
TEST_CASE(tanh_test)
{
    // Expected program: element-wise tanh of a 1x3x16x16 input, returned directly
    // (parse_tf, so no optimization passes run).
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input     = main_mod->add_parameter("0", input_shape);
    auto activated = main_mod->add_instruction(migraphx::make_op("tanh"), input);
    main_mod->add_return({activated});

    auto parsed = parse_tf("tanh_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
967
968
969
TEST_CASE(transpose_test)
{
    // Expected program: the input is transposed with permutation {0, 2, 3, 1}; the
    // permutation operand of the tf node is kept as an int32 literal.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});

    const migraphx::shape perm_shape{migraphx::shape::int32_type, {4}};
    main_mod->add_literal(migraphx::literal{perm_shape, {0, 2, 3, 1}});
    main_mod->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}),
                              input);

    auto parsed = optimize_tf("transpose_test.pb", false);
    EXPECT(expected == parsed);
}

981
982
983
TEST_CASE(variable_batch_test)
{
    // Expected program: an identity on a parameter whose batch dimension resolves to 1.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::make_op("identity"), input);

    auto parsed = optimize_tf("variable_batch_test.pb", false);
    EXPECT(expected == parsed);
}

993
int main(int argc, const char* argv[])
{
    // Forward the command line to the test driver, which runs the TEST_CASEs above.
    test::run(argc, argv);
}