// tf_test.cpp: tests for the MIGraphX TensorFlow (.pb) parser.
#include <algorithm>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include <migraphx/literal.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>

#include "test.hpp"

Shucai Xiao's avatar
Shucai Xiao committed
19
20
21
migraphx::program
parse_tf(const std::string& name,
         bool is_nhwc,
kahmed10's avatar
kahmed10 committed
22
23
         const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {},
         const std::vector<std::string>& output_node_names                           = {})
24
{
kahmed10's avatar
kahmed10 committed
25
26
    return migraphx::parse_tf(name,
                              migraphx::tf_options{is_nhwc, 1, dim_params, output_node_names});
27
28
}

Paul's avatar
Paul committed
29
30
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
31
    auto prog = migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1});
32
    auto* mm  = prog.get_main_module();
Paul's avatar
Paul committed
33
    if(is_nhwc)
34
        migraphx::run_passes(*mm,
Paul's avatar
Paul committed
35
36
37
                             {migraphx::simplify_reshapes{},
                              migraphx::dead_code_elimination{},
                              migraphx::eliminate_identity{}});
kahmed10's avatar
kahmed10 committed
38
39
40
41
42
43
44
45

    // remove the last return instruction
    auto last_ins = std::prev(mm->end());
    if(last_ins != mm->end())
        if(last_ins->name() == "@return")
        {
            mm->remove_instruction(last_ins);
        }
Paul's avatar
Paul committed
46
47
48
    return prog;
}

49
50
51
TEST_CASE(add_test)
{
    migraphx::program p;
52
53
54
55

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
56
    mm->add_instruction(migraphx::make_op("add"), l0, l1);
Paul's avatar
Paul committed
57
    auto prog = optimize_tf("add_test.pb", false);
58
59
60
61

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
62
63
64
TEST_CASE(addv2_test)
{
    migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
65
66
67
68
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    mm->add_instruction(migraphx::make_op("add"), l0, l1);
kahmed10's avatar
kahmed10 committed
69
70
71
72
73
    auto prog = optimize_tf("addv2_test.pb", false);

    EXPECT(p == prog);
}

74
75
TEST_CASE(add_bcast_test)
{
Khalique's avatar
Khalique committed
76

77
    migraphx::program p;
78
79

    auto* mm = p.get_main_module();
80
    migraphx::shape s0{migraphx::shape::float_type, {2, 3}};
81
82
    auto l0 = mm->add_parameter("0", s0);
    auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
83
84
    auto l2 =
        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1);
kahmed10's avatar
kahmed10 committed
85
    mm->add_instruction(migraphx::make_op("add"), l0, l2);
Paul's avatar
Paul committed
86
    auto prog = optimize_tf("add_bcast_test.pb", false);
87
88
89
90

    EXPECT(p == prog);
}

91
92
93
TEST_CASE(argmax_test)
{
    migraphx::program p;
94
95

    auto* mm = p.get_main_module();
Shucai Xiao's avatar
Shucai Xiao committed
96
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}});
97
    mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
98
    auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0);
kahmed10's avatar
kahmed10 committed
99
100
    auto l1  = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
    mm->add_return({l1});
Shucai Xiao's avatar
Shucai Xiao committed
101
    auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}});
102
103
104
105
106
107
108

    EXPECT(p == prog);
}

// argmin_test.pb: argmin over axis 2 followed by a squeeze of that axis. Parsed without
// optimization (parse_tf), so the int32 axis literal and the @return stay in the program.
TEST_CASE(argmin_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
    mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
    auto ins = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), l0);
    auto l1  = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
    mm->add_return({l1});
    auto prog = parse_tf("argmin_test.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
121
122
123
TEST_CASE(assert_less_equal_test)
{
    migraphx::program p;
124
125

    auto* mm = p.get_main_module();
Khalique's avatar
Khalique committed
126
    migraphx::shape s0{migraphx::shape::float_type, {2, 3}};
127
128
    auto l0 = mm->add_parameter("0", s0);
    auto l1 = mm->add_parameter("1", s0);
Khalique's avatar
Khalique committed
129
    migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}};
130
    auto l2 = mm->add_literal(l);
131
132
133
    mm->add_instruction(migraphx::make_op("add"), l0, l1);
    auto l3 = mm->add_instruction(migraphx::make_op("identity"), l0, l1);
    mm->add_instruction(migraphx::make_op("identity"), l3, l2);
Khalique's avatar
Khalique committed
134
135
136
137
138
    auto prog = optimize_tf("assert_less_equal_test.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
139
140
141
142
TEST_CASE(batchmatmul_test)
{
    migraphx::program p;

143
144
145
146
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});

147
148
149
150
    auto trans_l0 =
        mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0);
    auto trans_l1 =
        mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
Khalique's avatar
Khalique committed
151

152
    mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
Khalique's avatar
Khalique committed
153
154
155
156
157
    auto prog = optimize_tf("batchmatmul_test.pb", false);

    EXPECT(p == prog);
}

158
159
TEST_CASE(batchnorm_test)
{
Khalique's avatar
Khalique committed
160
161
    float epsilon  = 1.001e-5f;
    float momentum = 0.9f;
162
163

    migraphx::program p;
164
165

    auto* mm = p.get_main_module();
Khalique's avatar
Khalique committed
166
167
    migraphx::op::batch_norm_inference op{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
168
    migraphx::shape s0{migraphx::shape::float_type, {32}};
169
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
170
171
    std::vector<float> const_vals(32);
    std::fill(const_vals.begin(), const_vals.end(), 1.0f);
Khalique's avatar
Khalique committed
172

173
174
175
176
177
    auto l2 = mm->add_parameter("2", s0);
    auto l3 = mm->add_parameter("3", s0);
    auto l4 = mm->add_parameter("4", s0);
    auto l1 = mm->add_literal(migraphx::literal{s0, const_vals});
    mm->add_instruction(op, l0, l1, l2, l3, l4);
Paul's avatar
Paul committed
178
    auto prog = optimize_tf("batchnorm_test.pb", true);
179
180
181
182

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
183
184
185
186
187
188
TEST_CASE(batchnormv3_test)
{
    float epsilon  = 1.0e-5f;
    float momentum = 0.9f;

    migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
189
    auto* mm = p.get_main_module();
kahmed10's avatar
kahmed10 committed
190
191
192
    migraphx::op::batch_norm_inference op{
        epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
    migraphx::shape s0{migraphx::shape::float_type, {32}};
Shucai Xiao's avatar
Shucai Xiao committed
193
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
kahmed10's avatar
kahmed10 committed
194
195
196
    std::vector<float> const_vals(32);
    std::fill(const_vals.begin(), const_vals.end(), 1.0f);

Shucai Xiao's avatar
Shucai Xiao committed
197
198
199
200
201
    auto l2 = mm->add_parameter("2", s0);
    auto l3 = mm->add_parameter("3", s0);
    auto l4 = mm->add_parameter("4", s0);
    auto l1 = mm->add_literal(migraphx::literal{s0, const_vals});
    mm->add_instruction(op, l0, l1, l2, l3, l4);
kahmed10's avatar
kahmed10 committed
202
203
204
205
206
    auto prog = optimize_tf("batchnormv3_test.pb", true);

    EXPECT(p == prog);
}

207
208
209
TEST_CASE(biasadd_test)
{
    migraphx::program p;
210
211

    auto* mm = p.get_main_module();
212
    migraphx::shape s0{migraphx::shape::float_type, {1, 500, 1, 1}};
213
    uint64_t axis = 1;
214
215
    auto l0       = mm->add_parameter("0", s0);
    auto l1       = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
216
217
218
    auto l2       = mm->add_instruction(
        migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l0->get_shape().lens()}}), l1);
    mm->add_instruction(migraphx::make_op("add"), l0, l2);
Paul's avatar
Paul committed
219
    auto prog = optimize_tf("biasadd_test.pb", true);
220
221
222
223

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
224
225
226
227
228
229
230
231
232
233
TEST_CASE(biasadd_scalar_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    migraphx::shape s0{migraphx::shape::float_type, {1, 1}};
    uint64_t axis = 1;
    auto l0       = mm->add_parameter("0", s0);
    auto l1       = mm->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}});
234
235
236
    auto l2 = mm->add_instruction(
        migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l0->get_shape().lens()}}), l1);
    mm->add_instruction(migraphx::make_op("add"), l0, l2);
kahmed10's avatar
kahmed10 committed
237
238
239
240
241
    auto prog = optimize_tf("biasadd_scalar_test.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
242
243
244
TEST_CASE(cast_test)
{
    migraphx::program p;
245
246
247

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
248
249
250
251
    mm->add_instruction(
        migraphx::make_op("convert",
                          {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
        l0);
Khalique's avatar
Khalique committed
252
253
254
255
256
    auto prog = optimize_tf("cast_test.pb", false);

    EXPECT(p == prog);
}

257
258
259
TEST_CASE(concat_test)
{
    migraphx::program p;
Khalique's avatar
Khalique committed
260

261
262
263
264
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}});
    auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
265
266
267

    int axis = 1;
    // tf uses axis as the third input, and it is in int32 format
Khalique's avatar
Khalique committed
268
    // add the literal using a vector in order to set stride to 1 (like in tf parser)
269
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
270

271
    mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1);
Paul's avatar
Paul committed
272
    auto prog = optimize_tf("concat_test.pb", false);
273
274
275
276
277
278
279

    EXPECT(p == prog);
}

// constant_test.pb: the graph is a single scalar float literal with value 1.0.
TEST_CASE(const_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    mm->add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
    auto prog = optimize_tf("constant_test.pb", false);

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
288
migraphx::program create_conv()
289
290
{
    migraphx::program p;
Khalique's avatar
Khalique committed
291

292
293
294
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
Khalique's avatar
Khalique committed
295
    std::vector<float> weight_data(3 * 3 * 3 * 32);
296
    std::fill(weight_data.begin(), weight_data.end(), 1.0f);
Khalique's avatar
Khalique committed
297
    auto l1 =
298
        mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data);
299
300
301

    migraphx::op::convolution op;
    op.padding_mode = migraphx::op::padding_mode_t::same;
Khalique's avatar
Khalique committed
302
    op.padding      = {1, 1};
Khalique's avatar
Khalique committed
303
304
    op.stride       = {1, 1};
    op.dilation     = {1, 1};
305
    auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1);
306
    mm->add_instruction(op, l0, l2);
kahmed10's avatar
kahmed10 committed
307
308
309
310
311
312
313
314
315
316
317
    return p;
}

TEST_CASE(conv_test)
{
    // conv_test.pb, parsed as NHWC, must reduce to the shared create_conv() graph.
    migraphx::program expected = create_conv();
    auto parsed                = optimize_tf("conv_test.pb", true);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
318
319
320
321
322
323
324
325
326
327
328
TEST_CASE(conv_add_test)
{
    migraphx::program p = create_conv();
    auto* mm            = p.get_main_module();
    auto l0             = std::prev(mm->end());
    mm->add_instruction(migraphx::make_op("add"), l0, l0);
    auto prog = optimize_tf("conv_add_test.pb", true);

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
329
330
331
332
TEST_CASE(conv_nchw_test)
{
    migraphx::program p = create_conv();
    auto prog           = optimize_tf("conv_nchw_test.pb", false);
333
334
335
336

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
TEST_CASE(conv_relu_test)
{
    migraphx::program p = create_conv();
    auto* mm            = p.get_main_module();
    auto l0             = std::prev(mm->end());
    mm->add_instruction(migraphx::make_op("relu"), l0);
    auto prog = optimize_tf("conv_relu_test.pb", true);

    EXPECT(p == prog);
}

TEST_CASE(conv_relu6_test)
{
    // relu6 is expected as clip(conv, 0, 6). Both bounds are multibroadcast to the
    // {1, 32, 16, 16} convolution output shape.
    migraphx::program p = create_conv();
    auto* mm            = p.get_main_module();
    std::vector<size_t> out_lens{1, 32, 16, 16};
    auto conv = std::prev(mm->end());
    auto lo   = mm->add_literal(0.0f);
    auto hi   = mm->add_literal(6.0f);

    auto broadcast_to_out = [&](migraphx::instruction_ref ins) {
        return mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"output_lens", out_lens}}), ins);
    };
    auto lo_bcast = broadcast_to_out(lo);
    auto hi_bcast = broadcast_to_out(hi);
    mm->add_instruction(migraphx::make_op("clip"), conv, lo_bcast, hi_bcast);

    auto prog = optimize_tf("conv_relu6_test.pb", true);
    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
366
367
368
369
TEST_CASE(depthwiseconv_test)
{
    migraphx::program p;

370
371
372
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
Khalique's avatar
Khalique committed
373
374
375
    std::vector<float> weight_data(3 * 3 * 3 * 1);
    std::fill(weight_data.begin(), weight_data.end(), 1.0f);
    auto l1 =
376
        mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data);
Khalique's avatar
Khalique committed
377
378
379

    migraphx::op::convolution op;
    op.padding_mode = migraphx::op::padding_mode_t::same;
Khalique's avatar
Khalique committed
380
    op.padding      = {1, 1};
Khalique's avatar
Khalique committed
381
382
383
    op.stride       = {1, 1};
    op.dilation     = {1, 1};
    op.group        = 3;
384
385
386
    auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1);
    auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3);
    auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4);
387
    mm->add_instruction(op, l0, l5);
Paul's avatar
Paul committed
388
    auto prog = optimize_tf("depthwise_conv_test.pb", true);
Khalique's avatar
Khalique committed
389
390
391
392

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
393
394
395
TEST_CASE(expanddims_test)
{
    migraphx::program p;
Khalique's avatar
Khalique committed
396

397
398
399
400
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    mm->add_literal(0);
401
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4}}}), l0);
402
    auto prog = optimize_tf("expanddims_test.pb", false);
Khalique's avatar
Khalique committed
403
404
405
406
407
408
409
410
411

    EXPECT(p == prog);
}

TEST_CASE(expanddims_test_neg_dims)
{
    // Checks that the pb parser handles a negative dim value correctly: dim -1 appends a
    // trailing axis, so {2, 3, 4} is reshaped to {2, 3, 4, 1}.
    migraphx::program p;

    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
    mm->add_literal(-1);
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), l0);
    auto prog = optimize_tf("expanddims_neg_test.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
422
423
424
425
TEST_CASE(gather_test)
{
    migraphx::program p;

426
427
428
429
430
431
    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    auto l1 = mm->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}});
    mm->add_literal(1);
Khalique's avatar
Khalique committed
432
433

    int axis = 1;
434
    mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1);
Khalique's avatar
Khalique committed
435
436
    auto prog = optimize_tf("gather_test.pb", false);

Khalique's avatar
Khalique committed
437
438
439
    EXPECT(p == prog);
}

440
441
442
TEST_CASE(identity_test)
{
    migraphx::program p;
443
444
445

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
446
    mm->add_instruction(migraphx::make_op("identity"), l0);
Paul's avatar
Paul committed
447
    auto prog = optimize_tf("identity_test.pb", false);
448
449
450
451

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
452
453
454
TEST_CASE(matmul_test)
{
    migraphx::program p;
Khalique's avatar
Khalique committed
455

456
457
458
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
Khalique's avatar
Khalique committed
459

460
461
    auto trans_l0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l0);
    auto trans_l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
462

463
    mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
Paul's avatar
Paul committed
464
    auto prog = optimize_tf("matmul_test.pb", false);
Khalique's avatar
Khalique committed
465
466
467
468

    EXPECT(p == prog);
}

469
470
471
TEST_CASE(mean_test)
{
    migraphx::program p;
472
473

    auto* mm = p.get_main_module();
Khalique's avatar
Khalique committed
474
    migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}};
475
476
477
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    mm->add_literal(l);
    mm->add_literal(l);
Paul's avatar
Paul committed
478
    migraphx::op::reduce_mean op{{2, 3}};
479
480
    mm->add_instruction(op, l0);
    auto l3 = mm->add_instruction(op, l0);
481
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), l3);
Paul's avatar
Paul committed
482
    auto prog = optimize_tf("mean_test.pb", false);
483
484
485
486
487
488
489

    EXPECT(p == prog);
}

// mean_test_nhwc.pb (NHWC): the input is transposed to NHWC, reduced with reduce_mean over
// axes {1, 2}, and those axes are squeezed. No axes literal is expected, because
// optimize_tf runs dead_code_elimination in NHWC mode. (The unused local literal that the
// original declared here has been removed.)
TEST_CASE(mean_test_nhwc)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto l1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
    migraphx::op::reduce_mean op{{1, 2}};
    auto l2 = mm->add_instruction(op, l1);
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2);
    auto prog = optimize_tf("mean_test_nhwc.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
503
504
505
506
TEST_CASE(mul_test)
{
    migraphx::program p;

507
508
509
510
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});

511
    mm->add_instruction(migraphx::make_op("mul"), l0, l1);
Paul's avatar
Paul committed
512
    auto prog = optimize_tf("mul_test.pb", false);
Khalique's avatar
Khalique committed
513
514
515
516

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
TEST_CASE(multi_output_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto l1  = mm->add_instruction(migraphx::make_op("relu"), l0);
    auto l2  = mm->add_instruction(migraphx::make_op("tanh"), l0);
    mm->add_return({l1, l2});

    EXPECT(test::throws([&] { parse_tf("multi_output_test.pb", false, {}, {"relu", "relu6"}); }));
    auto prog = parse_tf("multi_output_test.pb", false, {}, {"relu", "tanh"});

    EXPECT(p == prog);
}

533
534
535
TEST_CASE(onehot_test)
{
    migraphx::program p;
536
537
538

    auto* mm = p.get_main_module();
    auto l0  = mm->add_literal(
Khalique's avatar
Khalique committed
539
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
540
541
542
543
    mm->add_literal(2);
    mm->add_literal(1.0f);
    mm->add_literal(0.0f);
    auto l1 = mm->add_literal(
Khalique's avatar
Khalique committed
544
        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});
545
    int axis = 0;
546
    mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l1, l0);
547
548
549
550
551
    auto prog = optimize_tf("onehot_test.pb", false);

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
552
553
554
555
556
557
558
559
TEST_CASE(noop_test)
{
    migraphx::program p;
    auto prog = optimize_tf("noop_test.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
560
561
562
TEST_CASE(pack_test)
{
    migraphx::program p;
563
564
565
566
567

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2}});
    auto l2  = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {2}});
Khalique's avatar
Khalique committed
568
569
570
571
    std::vector<migraphx::instruction_ref> args{l0, l1, l2};
    std::vector<migraphx::instruction_ref> unsqueezed_args;
    int64_t axis = 1;

572
573
574
575
576
577
578
579
580
    std::transform(
        args.begin(),
        args.end(),
        std::back_inserter(unsqueezed_args),
        [&](migraphx::instruction_ref arg) {
            return mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg);
        });
    mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(axis)}}),
                        unsqueezed_args);
Paul's avatar
Paul committed
581
    auto prog = optimize_tf("pack_test.pb", false);
Khalique's avatar
Khalique committed
582
583
584
585

    EXPECT(p == prog);
}

586
587
588
TEST_CASE(pack_test_nhwc)
{
    migraphx::program p;
589
590
591

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
592
    auto lt0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
593
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
594
    auto lt1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l1);
595
    auto l2  = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
596
    auto lt2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l2);
Paul's avatar
Paul committed
597
    std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
598
    std::vector<migraphx::instruction_ref> unsqueezed_args;
Paul's avatar
Paul committed
599
    int64_t nchw_axis = 3;
600
601
602
603
604

    std::transform(args.begin(),
                   args.end(),
                   std::back_inserter(unsqueezed_args),
                   [&](migraphx::instruction_ref arg) {
605
606
                       return mm->add_instruction(
                           migraphx::make_op("unsqueeze", {{"axes", {nchw_axis}}}), arg);
607
                   });
608
609
    mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(nchw_axis)}}),
                        unsqueezed_args);
Paul's avatar
Paul committed
610
    auto prog = optimize_tf("pack_test_nhwc.pb", true);
611
612
613
614

    EXPECT(p == prog);
}

kahmed10's avatar
kahmed10 committed
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
TEST_CASE(pad_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();

    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    std::vector<int> pad_literals{1, 1, 2, 2};
    std::vector<int> pads{1, 2, 1, 2};
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {2, 2}}, pad_literals);

    mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), l0);
    auto prog = optimize_tf("pad_test.pb", false);

    EXPECT(p == prog);
}

632
633
634
TEST_CASE(pooling_test)
{
    migraphx::program p;
635
636
637

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
638
639
    migraphx::op::pooling avg_pool_op{"average"};
    migraphx::op::pooling max_pool_op{"max"};
Shucai Xiao's avatar
Shucai Xiao committed
640
641
642
643
    avg_pool_op.stride  = {2, 2};
    max_pool_op.stride  = {2, 2};
    avg_pool_op.lengths = {2, 2};
    max_pool_op.lengths = {2, 2};
644
    mm->add_instruction(max_pool_op, l0);
Paul's avatar
Paul committed
645
    auto prog = optimize_tf("pooling_test.pb", true);
646
647
648
649

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
650
651
652
TEST_CASE(pow_test)
{
    migraphx::program p;
653
654
655
656

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
    auto l1  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
657
    mm->add_instruction(migraphx::make_op("pow"), l0, l1);
Khalique's avatar
Khalique committed
658
659
660
661
662
    auto prog = optimize_tf("pow_test.pb", false);

    EXPECT(p == prog);
}

663
664
665
TEST_CASE(relu_test)
{
    migraphx::program p;
666
667
668

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
669
    mm->add_instruction(migraphx::make_op("relu"), l0);
Paul's avatar
Paul committed
670
    auto prog = optimize_tf("relu_test.pb", false);
671
672
673
674

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
675
676
677
TEST_CASE(relu6_test)
{
    migraphx::program p;
678
679

    auto* mm = p.get_main_module();
kahmed10's avatar
kahmed10 committed
680
    std::vector<size_t> input_lens{1, 3, 16, 16};
681
682
683
    auto l0      = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
    auto min_val = mm->add_literal(0.0f);
    auto max_val = mm->add_literal(6.0f);
684
685
686
687
688
    min_val      = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
    max_val = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
    mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
Paul's avatar
Paul committed
689
    auto prog = optimize_tf("relu6_test.pb", false);
Khalique's avatar
Khalique committed
690
691
692
693

    EXPECT(p == prog);
}

694
695
696
TEST_CASE(reshape_test)
{
    migraphx::program p;
697
698
699

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});
700
701
    migraphx::shape s0{migraphx::shape::int32_type, {4}};
    // in tf, the second arg is a literal that contains new dimensions
702
    mm->add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
703
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 1, 16}}}), l0);
Paul's avatar
Paul committed
704
    auto prog = optimize_tf("reshape_test.pb", false);
705
706
707
708

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
709
710
711
TEST_CASE(rsqrt_test)
{
    migraphx::program p;
712
713
714

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
715
    mm->add_instruction(migraphx::make_op("rsqrt"), l0);
Khalique's avatar
Khalique committed
716
717
718
719
720
    auto prog = optimize_tf("rsqrt_test.pb", false);

    EXPECT(p == prog);
}

721
722
723
TEST_CASE(shape_test)
{
    migraphx::program p;
724
725
726
727

    auto* mm = p.get_main_module();
    mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    mm->add_literal(
728
729
730
731
732
733
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 16, 16}});
    auto prog = optimize_tf("shape_test.pb", false);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
734
735
736
TEST_CASE(slice_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}});

    // The two int32 literals look like TF Slice's `begin` {1, 0} and `size` {2, -1}
    // inputs: begin + size gives ends {3, 10}, with -1 meaning "to the end of the axis".
    migraphx::shape s0{migraphx::shape::int32_type, {2}};
    mm->add_literal(migraphx::literal{s0, {1, 0}});
    mm->add_literal(migraphx::literal{s0, {2, -1}});

    // Built through make_op (like the split tests) rather than op::slice + std::iota,
    // so the test no longer depends on <numeric> arriving via a transitive include.
    mm->add_instruction(
        migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {3, 10}}}), l0);

    auto prog = optimize_tf("slice_test.pb", false);

    EXPECT(p == prog);
}

756
757
758
TEST_CASE(softmax_test)
{
    // Expected graph: softmax over axis 1 of a 1x3 float input.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), input);

    auto parsed = optimize_tf("softmax_test.pb", false);
    EXPECT(expected == parsed);
}

kahmed10's avatar
kahmed10 committed
768
769
770
TEST_CASE(split_test)
{
    // Expected graph: the 5x30 input is cut into three 10-column slices and
    // neighbouring slices are concatenated along axis 1. parse_tf (unlike
    // optimize_tf) keeps the trailing @return, so the expected graph adds one.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    std::vector<int64_t> axes{0, 1};
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});

    // Scalar constants present in the TF graph.
    main_mod->add_literal(3); // num_splits
    main_mod->add_literal(1); // split axis
    main_mod->add_literal(1); // concat axis
    main_mod->add_literal(1); // concat axis

    // Adds a slice keeping all 5 rows and columns [first, last).
    auto column_slice = [&](int64_t first, int64_t last) {
        return main_mod->add_instruction(
            migraphx::make_op("slice",
                              {{"axes", axes},
                               {"starts", std::vector<int64_t>{0, first}},
                               {"ends", std::vector<int64_t>{5, last}}}),
            input);
    };
    auto left   = column_slice(0, 10);
    auto middle = column_slice(10, 20);
    auto right  = column_slice(20, 30);

    main_mod->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), left, middle);
    auto joined =
        main_mod->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), middle, right);
    main_mod->add_return({joined});

    auto parsed = parse_tf("split_test.pb", false);
    EXPECT(expected == parsed);
}

TEST_CASE(split_test_one_output)
{
    // Expected graph: with a single split the input passes through an identity,
    // which is returned (parse_tf keeps the trailing @return).
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
    main_mod->add_literal(1); // num_splits
    main_mod->add_literal(1); // split axis

    auto passthrough = main_mod->add_instruction(migraphx::make_op("identity"), input);
    main_mod->add_return({passthrough});

    auto parsed = parse_tf("split_test_one_output.pb", false);
    EXPECT(expected == parsed);
}

TEST_CASE(split_test_vector_as_input)
{
    // Expected graph: the 5x30 input is split by explicit column sizes {4, 15, 11}
    // (boundaries 0 / 4 / 19 / 30) and neighbouring pieces are concatenated along
    // axis 1. parse_tf keeps the trailing @return, so the expected graph adds one.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    std::vector<int64_t> axes{0, 1};
    auto input =
        main_mod->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});

    // split sizes
    const migraphx::shape sizes_shape{migraphx::shape::int32_type, {3}};
    main_mod->add_literal(migraphx::literal{sizes_shape, {4, 15, 11}});
    main_mod->add_literal(1); // split axis
    main_mod->add_literal(1); // concat axis
    main_mod->add_literal(1); // concat axis

    // Adds a slice keeping all 5 rows and columns [first, last).
    auto column_slice = [&](int64_t first, int64_t last) {
        return main_mod->add_instruction(
            migraphx::make_op("slice",
                              {{"axes", axes},
                               {"starts", std::vector<int64_t>{0, first}},
                               {"ends", std::vector<int64_t>{5, last}}}),
            input);
    };
    auto first_piece  = column_slice(0, 4);
    auto second_piece = column_slice(4, 19);
    auto third_piece  = column_slice(19, 30);

    main_mod->add_instruction(
        migraphx::make_op("concat", {{"axis", 1}}), first_piece, second_piece);
    auto joined = main_mod->add_instruction(
        migraphx::make_op("concat", {{"axis", 1}}), second_piece, third_piece);
    main_mod->add_return({joined});

    auto parsed = parse_tf("split_test_vector_as_input.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
835
836
837
TEST_CASE(sqdiff_test)
{
    // Expected graph: sqdiff of two identically shaped float inputs.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape operand_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
    auto lhs = main_mod->add_parameter("0", operand_shape);
    auto rhs = main_mod->add_parameter("1", operand_shape);
    main_mod->add_instruction(migraphx::make_op("sqdiff"), lhs, rhs);

    auto parsed = optimize_tf("sqdiff_test.pb", false);
    EXPECT(expected == parsed);
}

848
849
850
TEST_CASE(squeeze_test)
{
    // Expected graph: the unit dimensions 0 and 3 of a 1x2x3x1 input are squeezed away.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 3, 1}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 3}}}), input);

    auto parsed = optimize_tf("squeeze_test.pb", false);
    EXPECT(expected == parsed);
}
Khalique's avatar
Khalique committed
859

Khalique's avatar
Khalique committed
860
861
862
TEST_CASE(stopgradient_test)
{
    // Expected graph: TF StopGradient is expected to be parsed as a plain identity.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::make_op("identity"), input);

    auto parsed = optimize_tf("stopgradient_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
872
873
874
TEST_CASE(stridedslice_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
    // Parsed with is_nhwc = true: the input is permuted with dims {0, 2, 3, 1}
    // before the slice is applied.
    auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);

    // Built through make_op (like the split tests) rather than op::slice + std::iota,
    // so the test no longer depends on <numeric> arriving via a transitive include.
    auto l2 = mm->add_instruction(
        migraphx::make_op("slice",
                          {{"axes", {0, 1, 2, 3}}, {"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}}),
        l1);

    // Axis 1 of the sliced result has extent 1 and is squeezed away.
    auto shrink_axis = 1;
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);

    auto prog = optimize_tf("stridedslice_test.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
893
894
895
TEST_CASE(stridedslice_masks_test)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});

    // add literals for starts, ends, and strides in tf (NHWC format)
    // NOTE(review): the all-zero `ends` literal presumably relies on the op's
    // end_mask (hence the test name); confirm against stridedslice_masks_test.pb.
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 1, 1, 0});
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 0, 0, 0});
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{1, 1, 1, 1});

    // Permute to NHWC, slice there, then permute back; parse_tf keeps the trailing @return.
    auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
    // Built through make_op (like the split tests) rather than op::slice + std::iota,
    // so the test no longer depends on <numeric> arriving via a transitive include.
    auto l2 = mm->add_instruction(
        migraphx::make_op("slice",
                          {{"axes", {0, 1, 2, 3}}, {"starts", {0, 1, 1, 0}}, {"ends", {1, 3, 3, 10}}}),
        l1);
    auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2);
    mm->add_return({l3});

    auto prog = parse_tf("stridedslice_masks_test.pb", true);

    EXPECT(p == prog);
}

Khalique's avatar
Khalique committed
922
923
924
TEST_CASE(sub_test)
{
    // Expected graph: elementwise sub of two same-shaped inputs, returned
    // explicitly because parse_tf keeps the trailing @return.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape operand_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
    auto minuend    = main_mod->add_parameter("0", operand_shape);
    auto subtrahend = main_mod->add_parameter("1", operand_shape);
    auto difference = main_mod->add_instruction(migraphx::make_op("sub"), minuend, subtrahend);
    main_mod->add_return({difference});

    auto parsed = parse_tf("sub_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
936
937
938
TEST_CASE(tanh_test)
{
    // Expected graph: tanh of the input, returned explicitly because parse_tf
    // keeps the trailing @return.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input     = main_mod->add_parameter("0", input_shape);
    auto activated = main_mod->add_instruction(migraphx::make_op("tanh"), input);
    main_mod->add_return({activated});

    auto parsed = parse_tf("tanh_test.pb", false);
    EXPECT(expected == parsed);
}

Khalique's avatar
Khalique committed
949
950
951
TEST_CASE(transpose_test)
{
    // Expected graph: the permutation {0, 2, 3, 1} is kept as an int32 literal and
    // the input is transposed with the same dims.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = main_mod->add_parameter("0", input_shape);

    const migraphx::shape perm_shape{migraphx::shape::int32_type, {4}};
    main_mod->add_literal(migraphx::literal{perm_shape, {0, 2, 3, 1}});
    main_mod->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), input);

    auto parsed = optimize_tf("transpose_test.pb", false);
    EXPECT(expected == parsed);
}

963
964
965
TEST_CASE(variable_batch_test)
{
    // Expected graph: an identity over a 1x3x16x16 input.
    // NOTE(review): the leading 1 presumably comes from the batch-size value that
    // optimize_tf passes in tf_options; confirm against variable_batch_test.pb.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
    auto input = main_mod->add_parameter("0", input_shape);
    main_mod->add_instruction(migraphx::make_op("identity"), input);

    auto parsed = optimize_tf("variable_batch_test.pb", false);
    EXPECT(expected == parsed);
}

975
int main(int argc, const char* argv[])
{
    // Forward the command line to the driver declared in test.hpp.
    test::run(argc, argv);
}