tf_test.cpp 32.6 KB
Newer Older
1
2
#include <iostream>
#include <vector>
Shucai Xiao's avatar
Shucai Xiao committed
3
#include <unordered_map>
4
#include <migraphx/literal.hpp>
Paul's avatar
Paul committed
5
6
7
8
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
9
10
11
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp>
12
#include <migraphx/make_op.hpp>
turneram's avatar
turneram committed
13
14
15
16
17
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/slice.hpp>
18
19
20

#include <migraphx/serialize.hpp>

21
22
#include "test.hpp"

Shucai Xiao's avatar
Shucai Xiao committed
23
24
25
26
migraphx::program parse_tf(const std::string& name,
                           bool is_nhwc,
                           const std::unordered_map<std::string, std::vector<int>>& dim_params = {},
                           const std::vector<std::string>& output_node_names                   = {})
27
{
kahmed10's avatar
kahmed10 committed
28
29
    return migraphx::parse_tf(name,
                              migraphx::tf_options{is_nhwc, 1, dim_params, output_node_names});
30
31
}

Paul's avatar
Paul committed
32
33
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
34
    auto prog = migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1});
35
    auto* mm  = prog.get_main_module();
Paul's avatar
Paul committed
36
    if(is_nhwc)
37
        migraphx::run_passes(*mm,
Paul's avatar
Paul committed
38
39
40
                             {migraphx::simplify_reshapes{},
                              migraphx::dead_code_elimination{},
                              migraphx::eliminate_identity{}});
kahmed10's avatar
kahmed10 committed
41
42
43
44

    // remove the last return instruction
    auto last_ins = std::prev(mm->end());
    if(last_ins != mm->end())
kahmed10's avatar
kahmed10 committed
45
    {
kahmed10's avatar
kahmed10 committed
46
47
48
49
        if(last_ins->name() == "@return")
        {
            mm->remove_instruction(last_ins);
        }
kahmed10's avatar
kahmed10 committed
50
    }
Paul's avatar
Paul committed
51
52
53
    return prog;
}

54
55
56
TEST_CASE(add_test)
{
    migraphx::program p;
57
58
59
60

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
61
    mm->add_instruction(migraphx::make_op("add"), l0, l1);
Paul's avatar
Paul committed
62
    auto prog = optimize_tf("add_test.pb", false);
63
64
65
66

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
67
68
69
TEST_CASE(addv2_test)
{
    migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
70
71
72
73
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    mm->add_instruction(migraphx::make_op("add"), l0, l1);
kahmed10's avatar
kahmed10 committed
74
75
76
77
78
    auto prog = optimize_tf("addv2_test.pb", false);

    EXPECT(p == prog);
}

79
80
TEST_CASE(add_bcast_test)
{
Khalique's avatar
Khalique committed
81

82
    migraphx::program p;
83
84

    auto* mm = p.get_main_module();
85
    migraphx::shape s0{migraphx::shape::float_type, {2, 3}};
86
87
    auto l0 = mm->add_parameter("0", s0);
    auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
88
    auto l2 =
89
        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1);
kahmed10's avatar
kahmed10 committed
90
    mm->add_instruction(migraphx::make_op("add"), l0, l2);
Paul's avatar
Paul committed
91
    auto prog = optimize_tf("add_bcast_test.pb", false);
92
93
94
95

    EXPECT(p == prog);
}

96
97
98
TEST_CASE(argmax_test)
{
    migraphx::program p;
99
100

    auto* mm = p.get_main_module();
Shucai Xiao's avatar
Shucai Xiao committed
101
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}});
102
    mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
103
    auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0);
kahmed10's avatar
kahmed10 committed
104
105
    auto l1  = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
    mm->add_return({l1});
Shucai Xiao's avatar
Shucai Xiao committed
106
    auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}});
107
108
109
110
111
112
113

    EXPECT(p == prog);
}

TEST_CASE(argmin_test)
{
    migraphx::program p;
114
115
116
117

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
    mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
118
    auto ins = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), l0);
kahmed10's avatar
kahmed10 committed
119
120
    auto l1  = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
    mm->add_return({l1});
121
122
123
124
125
    auto prog = parse_tf("argmin_test.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
126
127
128
TEST_CASE(assert_less_equal_test)
{
    migraphx::program p;
129
130

    auto* mm = p.get_main_module();
Khalique's avatar
Khalique committed
131
    migraphx::shape s0{migraphx::shape::float_type, {2, 3}};
132
133
    auto l0 = mm->add_parameter("0", s0);
    auto l1 = mm->add_parameter("1", s0);
Khalique's avatar
Khalique committed
134
    migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}};
135
    auto l2 = mm->add_literal(l);
136
137
138
    mm->add_instruction(migraphx::make_op("add"), l0, l1);
    auto l3 = mm->add_instruction(migraphx::make_op("identity"), l0, l1);
    mm->add_instruction(migraphx::make_op("identity"), l3, l2);
Khalique's avatar
Khalique committed
139
140
141
142
143
    auto prog = optimize_tf("assert_less_equal_test.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
144
145
146
147
TEST_CASE(batchmatmul_test)
{
    migraphx::program p;

148
149
150
151
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});

152
    auto trans_l0 =
153
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l0);
154
    auto trans_l1 =
155
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
Khalique's avatar
Khalique committed
156

157
    mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
Khalique's avatar
Khalique committed
158
159
160
161
162
    auto prog = optimize_tf("batchmatmul_test.pb", false);

    EXPECT(p == prog);
}

163
164
TEST_CASE(batchnorm_test)
{
Khalique's avatar
Khalique committed
165
166
    float epsilon  = 1.001e-5f;
    float momentum = 0.9f;
167
168

    migraphx::program p;
169
170

    auto* mm = p.get_main_module();
Khalique's avatar
Khalique committed
171
172
    migraphx::op::batch_norm_inference op{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
173
    migraphx::shape s0{migraphx::shape::float_type, {32}};
174
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
175
176
    std::vector<float> const_vals(32);
    std::fill(const_vals.begin(), const_vals.end(), 1.0f);
Khalique's avatar
Khalique committed
177

178
179
180
181
182
    auto l2 = mm->add_parameter("2", s0);
    auto l3 = mm->add_parameter("3", s0);
    auto l4 = mm->add_parameter("4", s0);
    auto l1 = mm->add_literal(migraphx::literal{s0, const_vals});
    mm->add_instruction(op, l0, l1, l2, l3, l4);
Paul's avatar
Paul committed
183
    auto prog = optimize_tf("batchnorm_test.pb", true);
184
185
186
187

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
188
189
190
191
192
193
TEST_CASE(batchnormv3_test)
{
    float epsilon  = 1.0e-5f;
    float momentum = 0.9f;

    migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
194
    auto* mm = p.get_main_module();
kahmed10's avatar
kahmed10 committed
195
196
197
    migraphx::op::batch_norm_inference op{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
    migraphx::shape s0{migraphx::shape::float_type, {32}};
Shucai Xiao's avatar
Shucai Xiao committed
198
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
kahmed10's avatar
kahmed10 committed
199
200
201
    std::vector<float> const_vals(32);
    std::fill(const_vals.begin(), const_vals.end(), 1.0f);

Shucai Xiao's avatar
Shucai Xiao committed
202
203
204
205
206
    auto l2 = mm->add_parameter("2", s0);
    auto l3 = mm->add_parameter("3", s0);
    auto l4 = mm->add_parameter("4", s0);
    auto l1 = mm->add_literal(migraphx::literal{s0, const_vals});
    mm->add_instruction(op, l0, l1, l2, l3, l4);
kahmed10's avatar
kahmed10 committed
207
208
209
210
211
    auto prog = optimize_tf("batchnormv3_test.pb", true);

    EXPECT(p == prog);
}

212
213
214
TEST_CASE(biasadd_test)
{
    migraphx::program p;
215
216

    auto* mm = p.get_main_module();
217
    migraphx::shape s0{migraphx::shape::float_type, {1, 500, 1, 1}};
218
    uint64_t axis = 1;
219
220
    auto l0       = mm->add_parameter("0", s0);
    auto l1       = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
221
    auto l2       = mm->add_instruction(
222
        migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1);
223
    mm->add_instruction(migraphx::make_op("add"), l0, l2);
Paul's avatar
Paul committed
224
    auto prog = optimize_tf("biasadd_test.pb", true);
225
226
227
228

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
229
230
231
232
233
234
235
236
237
238
TEST_CASE(biasadd_scalar_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    migraphx::shape s0{migraphx::shape::float_type, {1, 1}};
    uint64_t axis = 1;
    auto l0       = mm->add_parameter("0", s0);
    auto l1       = mm->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}});
239
    auto l2 = mm->add_instruction(
240
        migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1);
241
    mm->add_instruction(migraphx::make_op("add"), l0, l2);
kahmed10's avatar
kahmed10 committed
242
243
244
245
246
    auto prog = optimize_tf("biasadd_scalar_test.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
247
248
249
TEST_CASE(cast_test)
{
    migraphx::program p;
250
251
252

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
253
254
255
256
    mm->add_instruction(
        migraphx::make_op("convert",
                          {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
        l0);
Khalique's avatar
Khalique committed
257
258
259
260
261
    auto prog = optimize_tf("cast_test.pb", false);

    EXPECT(p == prog);
}

262
263
264
TEST_CASE(concat_test)
{
    migraphx::program p;
Khalique's avatar
Khalique committed
265

266
267
268
269
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}});
    auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
270
271
272

    int axis = 1;
    // tf uses axis as the third input, and it is in int32 format
Khalique's avatar
Khalique committed
273
    // add the literal using a vector in order to set stride to 1 (like in tf parser)
274
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
275

276
    mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1);
Paul's avatar
Paul committed
277
    auto prog = optimize_tf("concat_test.pb", false);
278
279
280
281
282
283
284

    EXPECT(p == prog);
}

TEST_CASE(const_test)
{
    migraphx::program p;
285
286
287

    auto* mm = p.get_main_module();
    mm->add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
Paul's avatar
Paul committed
288
    auto prog = optimize_tf("constant_test.pb", false);
289
290
291
292

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
293
migraphx::program create_conv()
294
295
{
    migraphx::program p;
Khalique's avatar
Khalique committed
296

297
298
299
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
Khalique's avatar
Khalique committed
300
    std::vector<float> weight_data(3 * 3 * 3 * 32);
301
    std::fill(weight_data.begin(), weight_data.end(), 1.0f);
Khalique's avatar
Khalique committed
302
    auto l1 =
303
        mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data);
304
305
306

    migraphx::op::convolution op;
    op.padding_mode = migraphx::op::padding_mode_t::same;
kahmed10's avatar
kahmed10 committed
307
    op.padding      = {1, 1, 1, 1};
Khalique's avatar
Khalique committed
308
309
    op.stride       = {1, 1};
    op.dilation     = {1, 1};
310
311
    auto l2 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1);
312
    mm->add_instruction(op, l0, l2);
kahmed10's avatar
kahmed10 committed
313
314
315
316
317
318
319
320
321
322
323
    return p;
}

TEST_CASE(conv_test)
{
    migraphx::program p = create_conv();
    auto prog           = optimize_tf("conv_test.pb", true);

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
324
325
326
327
328
329
330
331
332
333
334
TEST_CASE(conv_add_test)
{
    migraphx::program p = create_conv();
    auto* mm            = p.get_main_module();
    auto l0             = std::prev(mm->end());
    mm->add_instruction(migraphx::make_op("add"), l0, l0);
    auto prog = optimize_tf("conv_add_test.pb", true);

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
335
336
337
338
TEST_CASE(conv_nchw_test)
{
    migraphx::program p = create_conv();
    auto prog           = optimize_tf("conv_nchw_test.pb", false);
339
340
341
342

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
TEST_CASE(conv_relu_test)
{
    migraphx::program p = create_conv();
    auto* mm            = p.get_main_module();
    auto l0             = std::prev(mm->end());
    mm->add_instruction(migraphx::make_op("relu"), l0);
    auto prog = optimize_tf("conv_relu_test.pb", true);

    EXPECT(p == prog);
}

TEST_CASE(conv_relu6_test)
{
    migraphx::program p = create_conv();
    auto* mm            = p.get_main_module();
Shucai Xiao's avatar
Shucai Xiao committed
358
    std::vector<int> input_lens{1, 32, 16, 16};
kahmed10's avatar
kahmed10 committed
359
360
361
    auto l0      = std::prev(mm->end());
    auto min_val = mm->add_literal(0.0f);
    auto max_val = mm->add_literal(6.0f);
362
363
364
365
    min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
                                  min_val);
    max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
                                  max_val);
kahmed10's avatar
kahmed10 committed
366
367
368
369
370
371
    mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
    auto prog = optimize_tf("conv_relu6_test.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
372
373
374
375
TEST_CASE(depthwiseconv_test)
{
    migraphx::program p;

376
377
378
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
Khalique's avatar
Khalique committed
379
380
381
    std::vector<float> weight_data(3 * 3 * 3 * 1);
    std::fill(weight_data.begin(), weight_data.end(), 1.0f);
    auto l1 =
382
        mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data);
Khalique's avatar
Khalique committed
383
384
385

    migraphx::op::convolution op;
    op.padding_mode = migraphx::op::padding_mode_t::same;
Khalique's avatar
Khalique committed
386
    op.padding      = {1, 1};
Khalique's avatar
Khalique committed
387
388
389
    op.stride       = {1, 1};
    op.dilation     = {1, 1};
    op.group        = 3;
390
391
    auto l3 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1);
392
393
    auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3);
    auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4);
394
    mm->add_instruction(op, l0, l5);
Paul's avatar
Paul committed
395
    auto prog = optimize_tf("depthwise_conv_test.pb", true);
Khalique's avatar
Khalique committed
396
397
398
399

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
400
401
402
TEST_CASE(expanddims_test)
{
    migraphx::program p;
Khalique's avatar
Khalique committed
403

404
405
406
407
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    mm->add_literal(0);
408
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4}}}), l0);
409
    auto prog = optimize_tf("expanddims_test.pb", false);
Khalique's avatar
Khalique committed
410
411
412
413
414
415
416
417
418

    EXPECT(p == prog);
}

TEST_CASE(expanddims_test_neg_dims)
{
    // this check makes sure the pb parses negative dim value correctly
    migraphx::program p;

419
420
421
422
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    mm->add_literal(-1);
423
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), l0);
424
    auto prog = optimize_tf("expanddims_neg_test.pb", false);
Khalique's avatar
Khalique committed
425
426
427
428

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
429
430
431
432
TEST_CASE(gather_test)
{
    migraphx::program p;

433
434
435
436
437
438
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    auto l1 = mm->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}});
    mm->add_literal(1);
Khalique's avatar
Khalique committed
439
440

    int axis = 1;
441
    mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1);
Khalique's avatar
Khalique committed
442
443
    auto prog = optimize_tf("gather_test.pb", false);

Khalique's avatar
Khalique committed
444
445
446
    EXPECT(p == prog);
}

447
448
449
TEST_CASE(identity_test)
{
    migraphx::program p;
450
451
452

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
453
    mm->add_instruction(migraphx::make_op("identity"), l0);
Paul's avatar
Paul committed
454
    auto prog = optimize_tf("identity_test.pb", false);
455
456
457
458

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
459
460
461
TEST_CASE(matmul_test)
{
    migraphx::program p;
Khalique's avatar
Khalique committed
462

463
464
465
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
Khalique's avatar
Khalique committed
466

467
468
469
470
    auto trans_l0 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l0);
    auto trans_l1 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
471

472
    mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
Paul's avatar
Paul committed
473
    auto prog = optimize_tf("matmul_test.pb", false);
Khalique's avatar
Khalique committed
474
475
476
477

    EXPECT(p == prog);
}

478
479
480
TEST_CASE(mean_test)
{
    migraphx::program p;
481
482

    auto* mm = p.get_main_module();
Khalique's avatar
Khalique committed
483
    migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}};
484
485
486
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    mm->add_literal(l);
    mm->add_literal(l);
Paul's avatar
Paul committed
487
    migraphx::op::reduce_mean op{{2, 3}};
488
489
    mm->add_instruction(op, l0);
    auto l3 = mm->add_instruction(op, l0);
490
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), l3);
Paul's avatar
Paul committed
491
    auto prog = optimize_tf("mean_test.pb", false);
492
493
494
495
496
497
498

    EXPECT(p == prog);
}

TEST_CASE(mean_test_nhwc)
{
    migraphx::program p;
499
500

    auto* mm = p.get_main_module();
Khalique's avatar
Khalique committed
501
    migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
502
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
503
504
    auto l1 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
Khalique's avatar
Khalique committed
505
    migraphx::op::reduce_mean op{{1, 2}};
506
    auto l2 = mm->add_instruction(op, l1);
507
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2);
Paul's avatar
Paul committed
508
    auto prog = optimize_tf("mean_test_nhwc.pb", true);
509
510
511
512

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
513
514
515
516
TEST_CASE(mul_test)
{
    migraphx::program p;

517
518
519
520
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});

521
    mm->add_instruction(migraphx::make_op("mul"), l0, l1);
Paul's avatar
Paul committed
522
    auto prog = optimize_tf("mul_test.pb", false);
Khalique's avatar
Khalique committed
523
524
525
526

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
TEST_CASE(multi_output_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto l1  = mm->add_instruction(migraphx::make_op("relu"), l0);
    auto l2  = mm->add_instruction(migraphx::make_op("tanh"), l0);
    mm->add_return({l1, l2});

    EXPECT(test::throws([&] { parse_tf("multi_output_test.pb", false, {}, {"relu", "relu6"}); }));
    auto prog = parse_tf("multi_output_test.pb", false, {}, {"relu", "tanh"});

    EXPECT(p == prog);
}

543
544
545
TEST_CASE(onehot_test)
{
    migraphx::program p;
546
547
548

    auto* mm = p.get_main_module();
    auto l0  = mm->add_literal(
Khalique's avatar
Khalique committed
549
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
550
551
552
553
    mm->add_literal(2);
    mm->add_literal(1.0f);
    mm->add_literal(0.0f);
    auto l1 = mm->add_literal(
Khalique's avatar
Khalique committed
554
        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});
555
    int axis = 0;
556
    mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l1, l0);
557
558
559
560
561
    auto prog = optimize_tf("onehot_test.pb", false);

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
562
563
564
565
566
567
568
569
TEST_CASE(noop_test)
{
    migraphx::program p;
    auto prog = optimize_tf("noop_test.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
570
571
572
TEST_CASE(pack_test)
{
    migraphx::program p;
573
574
575
576
577

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2}});
    auto l2  = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {2}});
Khalique's avatar
Khalique committed
578
579
580
581
    std::vector<migraphx::instruction_ref> args{l0, l1, l2};
    std::vector<migraphx::instruction_ref> unsqueezed_args;
    int64_t axis = 1;

582
583
584
585
586
587
588
589
590
    std::transform(
        args.begin(),
        args.end(),
        std::back_inserter(unsqueezed_args),
        [&](migraphx::instruction_ref arg) {
            return mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg);
        });
    mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(axis)}}),
                        unsqueezed_args);
Paul's avatar
Paul committed
591
    auto prog = optimize_tf("pack_test.pb", false);
Khalique's avatar
Khalique committed
592
593
594
595

    EXPECT(p == prog);
}

596
597
598
TEST_CASE(pack_test_nhwc)
{
    migraphx::program p;
599
600
601

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
602
603
604
605
606
607
608
609
    auto lt0 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
    auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
    auto lt1 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1);
    auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
    auto lt2 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2);
Paul's avatar
Paul committed
610
    std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
611
    std::vector<migraphx::instruction_ref> unsqueezed_args;
Paul's avatar
Paul committed
612
    int64_t nchw_axis = 3;
613
614
615
616
617

    std::transform(args.begin(),
                   args.end(),
                   std::back_inserter(unsqueezed_args),
                   [&](migraphx::instruction_ref arg) {
618
619
                       return mm->add_instruction(
                           migraphx::make_op("unsqueeze", {{"axes", {nchw_axis}}}), arg);
620
                   });
621
622
    mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(nchw_axis)}}),
                        unsqueezed_args);
Paul's avatar
Paul committed
623
    auto prog = optimize_tf("pack_test_nhwc.pb", true);
624
625
626
627

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
TEST_CASE(pad_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    std::vector<int> pad_literals{1, 1, 2, 2};
    std::vector<int> pads{1, 2, 1, 2};
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {2, 2}}, pad_literals);

    mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), l0);
    auto prog = optimize_tf("pad_test.pb", false);

    EXPECT(p == prog);
}

645
646
647
TEST_CASE(pooling_test)
{
    migraphx::program p;
648
649
650

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
651
652
    migraphx::op::pooling avg_pool_op{"average"};
    migraphx::op::pooling max_pool_op{"max"};
Shucai Xiao's avatar
Shucai Xiao committed
653
654
655
656
    avg_pool_op.stride  = {2, 2};
    max_pool_op.stride  = {2, 2};
    avg_pool_op.lengths = {2, 2};
    max_pool_op.lengths = {2, 2};
kahmed10's avatar
kahmed10 committed
657
    mm->add_instruction(avg_pool_op, l0);
658
    mm->add_instruction(max_pool_op, l0);
Paul's avatar
Paul committed
659
    auto prog = optimize_tf("pooling_test.pb", true);
660
661
662
663

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
664
665
666
TEST_CASE(pow_test)
{
    migraphx::program p;
667
668
669
670

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
671
    mm->add_instruction(migraphx::make_op("pow"), l0, l1);
Khalique's avatar
Khalique committed
672
673
674
675
676
    auto prog = optimize_tf("pow_test.pb", false);

    EXPECT(p == prog);
}

677
678
679
TEST_CASE(relu_test)
{
    migraphx::program p;
680
681
682

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
683
    mm->add_instruction(migraphx::make_op("relu"), l0);
Paul's avatar
Paul committed
684
    auto prog = optimize_tf("relu_test.pb", false);
685
686
687
688

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
689
690
691
TEST_CASE(relu6_test)
{
    migraphx::program p;
692
693

    auto* mm = p.get_main_module();
Shucai Xiao's avatar
Shucai Xiao committed
694
    std::vector<int> input_lens{1, 3, 16, 16};
695
696
697
    auto l0      = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
    auto min_val = mm->add_literal(0.0f);
    auto max_val = mm->add_literal(6.0f);
698
699
700
701
    min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
                                  min_val);
    max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
                                  max_val);
702
    mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
Paul's avatar
Paul committed
703
    auto prog = optimize_tf("relu6_test.pb", false);
Khalique's avatar
Khalique committed
704
705
706
707

    EXPECT(p == prog);
}

708
709
710
TEST_CASE(reshape_test)
{
    migraphx::program p;
711
712
713

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});
714
715
    migraphx::shape s0{migraphx::shape::int32_type, {4}};
    // in tf, the second arg is a literal that contains new dimensions
716
    mm->add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
717
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 1, 16}}}), l0);
Paul's avatar
Paul committed
718
    auto prog = optimize_tf("reshape_test.pb", false);
719
720
721
722

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
723
724
725
TEST_CASE(rsqrt_test)
{
    migraphx::program p;
726
727
728

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
729
    mm->add_instruction(migraphx::make_op("rsqrt"), l0);
Khalique's avatar
Khalique committed
730
731
732
733
734
    auto prog = optimize_tf("rsqrt_test.pb", false);

    EXPECT(p == prog);
}

735
736
737
TEST_CASE(shape_test)
{
    migraphx::program p;
738
739
740
741

    auto* mm = p.get_main_module();
    mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    mm->add_literal(
742
743
744
745
746
747
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 16, 16}});
    auto prog = optimize_tf("shape_test.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
748
749
750
TEST_CASE(slice_test)
{
    migraphx::program p;
751

Shucai Xiao's avatar
Shucai Xiao committed
752
    auto* mm     = p.get_main_module();
Shucai Xiao's avatar
Shucai Xiao committed
753
    int num_axes = 2;
Shucai Xiao's avatar
Shucai Xiao committed
754
    auto l0      = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}});
Khalique's avatar
Khalique committed
755
    migraphx::shape s0{migraphx::shape::int32_type, {num_axes}};
756
757
    mm->add_literal(migraphx::literal{s0, {1, 0}});
    mm->add_literal(migraphx::literal{s0, {2, -1}});
Khalique's avatar
Khalique committed
758
759
760
761
762
763

    migraphx::op::slice op;
    op.starts = {1, 0};
    op.ends   = {3, 10};
    op.axes   = std::vector<int64_t>(num_axes);
    std::iota(op.axes.begin(), op.axes.end(), 0);
764
    mm->add_instruction(op, l0);
Khalique's avatar
Khalique committed
765
766
767
768
769
    auto prog = optimize_tf("slice_test.pb", false);

    EXPECT(p == prog);
}

770
771
772
TEST_CASE(softmax_test)
{
    migraphx::program p;
773
774
775

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
776
    mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), l0);
Paul's avatar
Paul committed
777
    auto prog = optimize_tf("softmax_test.pb", false);
778
779
780
781

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
782
783
784
TEST_CASE(split_test)
{
    migraphx::program p;
785
786

    auto* mm = p.get_main_module();
kahmed10's avatar
kahmed10 committed
787
    std::vector<int64_t> axes{0, 1};
788
789
790
791
792
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    mm->add_literal(3); // num_splits
    mm->add_literal(1); // split axis
    mm->add_literal(1); // concat axis
    mm->add_literal(1); // concat axis
793
794
795
796
797
798
    auto l1 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 10}}}), l0);
    auto l2 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 10}}, {"ends", {5, 20}}}), l0);
    auto l3 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0);
kahmed10's avatar
kahmed10 committed
799
800
801
    auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
    auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
    mm->add_return({l4, l5});
802
    auto prog = parse_tf("split_test.pb", false);
kahmed10's avatar
kahmed10 committed
803
804
805
806
807
808
809

    EXPECT(p == prog);
}

TEST_CASE(split_test_one_output)
{
    migraphx::program p;
810
811
812
813
814

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    mm->add_literal(1); // num_splits
    mm->add_literal(1); // split axis
kahmed10's avatar
kahmed10 committed
815
816
    auto l1 = mm->add_instruction(migraphx::make_op("identity"), l0);
    mm->add_return({l1});
817
    auto prog = parse_tf("split_test_one_output.pb", false);
kahmed10's avatar
kahmed10 committed
818
819
820
821
822
823
824

    EXPECT(p == prog);
}

TEST_CASE(split_test_vector_as_input)
{
    migraphx::program p;
825
826

    auto* mm = p.get_main_module();
kahmed10's avatar
kahmed10 committed
827
    std::vector<int64_t> axes{0, 1};
828
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
kahmed10's avatar
kahmed10 committed
829
    // split sizes
830
    mm->add_literal(
kahmed10's avatar
kahmed10 committed
831
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {3}}, {4, 15, 11}});
832
833
834
    mm->add_literal(1); // split axis
    mm->add_literal(1); // concat axis
    mm->add_literal(1); // concat axis
835
836
837
838
839
840
    auto l1 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 4}}}), l0);
    auto l2 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 4}}, {"ends", {5, 19}}}), l0);
    auto l3 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0);
kahmed10's avatar
kahmed10 committed
841
842
843
    auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
    auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
    mm->add_return({l4, l5});
844
    auto prog = parse_tf("split_test_vector_as_input.pb", false);
kahmed10's avatar
kahmed10 committed
845
846
847
848

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
849
850
851
TEST_CASE(sqdiff_test)
{
    migraphx::program p;
852
853
854
855

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
856
    mm->add_instruction(migraphx::make_op("sqdiff"), l0, l1);
Khalique's avatar
Khalique committed
857
858
859
860
861
    auto prog = optimize_tf("sqdiff_test.pb", false);

    EXPECT(p == prog);
}

862
863
864
TEST_CASE(squeeze_test)
{
    migraphx::program p;
865
866
867

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
868
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 3}}}), l0);
Paul's avatar
Paul committed
869
    auto prog = optimize_tf("squeeze_test.pb", false);
870
871
872

    EXPECT(p == prog);
}
Khalique's avatar
Khalique committed
873

Khalique's avatar
Khalique committed
874
875
876
TEST_CASE(stopgradient_test)
{
    migraphx::program p;
877
878
879

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
880
    mm->add_instruction(migraphx::make_op("identity"), l0);
Khalique's avatar
Khalique committed
881
882
    auto prog = optimize_tf("stopgradient_test.pb", false);

Khalique's avatar
Khalique committed
883
    EXPECT(p == prog);
Khalique's avatar
Khalique committed
884
885
}

Khalique's avatar
Khalique committed
886
887
888
TEST_CASE(stridedslice_test)
{
    migraphx::program p;
889
890
891

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
892
893
    auto l1 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
Shucai Xiao's avatar
Shucai Xiao committed
894
    int num_axes = 4;
Khalique's avatar
Khalique committed
895
    migraphx::op::slice op;
Khalique's avatar
Khalique committed
896
    op.starts = {0, 0, 0, 0};
Paul's avatar
Paul committed
897
    op.ends   = {1, 1, 1, 5};
Khalique's avatar
Khalique committed
898
899
    op.axes   = std::vector<int64_t>(num_axes);
    std::iota(op.axes.begin(), op.axes.end(), 0);
900
    auto l2          = mm->add_instruction(op, l1);
Paul's avatar
Paul committed
901
    auto shrink_axis = 1;
902
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);
Paul's avatar
Paul committed
903
    auto prog = optimize_tf("stridedslice_test.pb", true);
Khalique's avatar
Khalique committed
904
905
906
907

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
908
909
910
TEST_CASE(stridedslice_masks_test)
{
    migraphx::program p;
911
912
913

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
Shucai Xiao's avatar
Shucai Xiao committed
914
    int num_axes = 4;
Khalique's avatar
Khalique committed
915
    migraphx::op::slice op;
916
917
    op.starts = {0, 1, 1, 0};
    op.ends   = {1, 3, 3, 10};
Khalique's avatar
Khalique committed
918
919
920
    op.axes   = std::vector<int64_t>(num_axes);
    std::iota(op.axes.begin(), op.axes.end(), 0);
    // add literals for starts, ends, and strides in tf (NHWC format)
921
922
923
924
925
926
927
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 1, 1, 0});
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 0, 0, 0});
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{1, 1, 1, 1});

928
929
    auto l1 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
930
    auto l2 = mm->add_instruction(op, l1);
931
932
    auto l3 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2);
kahmed10's avatar
kahmed10 committed
933
    mm->add_return({l3});
934
    auto prog = parse_tf("stridedslice_masks_test.pb", true);
Khalique's avatar
Khalique committed
935
936
937
938

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
939
940
941
TEST_CASE(sub_test)
{
    migraphx::program p;
942
943
944
945

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
kahmed10's avatar
kahmed10 committed
946
947
    auto l2  = mm->add_instruction(migraphx::make_op("sub"), l0, l1);
    mm->add_return({l2});
948
    auto prog = parse_tf("sub_test.pb", false);
Khalique's avatar
Khalique committed
949
950
951
952

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
953
954
955
TEST_CASE(tanh_test)
{
    migraphx::program p;
956
957

    auto* mm = p.get_main_module();
kahmed10's avatar
kahmed10 committed
958
959
960
961
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto l1  = mm->add_instruction(migraphx::make_op("tanh"), l0);
    mm->add_return({l1});
    auto prog = parse_tf("tanh_test.pb", false);
Khalique's avatar
Khalique committed
962
963
964
965

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
966
967
968
TEST_CASE(transpose_test)
{
    migraphx::program p;
969
970
971

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
Khalique's avatar
Khalique committed
972
    migraphx::shape s0{migraphx::shape::int32_type, {4}};
973
    mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}});
974
    mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
Khalique's avatar
Khalique committed
975
976
977
978
979
    auto prog = optimize_tf("transpose_test.pb", false);

    EXPECT(p == prog);
}

980
981
982
TEST_CASE(variable_batch_test)
{
    migraphx::program p;
983
984
985

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
986
    mm->add_instruction(migraphx::make_op("identity"), l0);
987
988
989
990
991
    auto prog = optimize_tf("variable_batch_test.pb", false);

    EXPECT(p == prog);
}

992
int main(int argc, const char* argv[]) { test::run(argc, argv); }