// op_shape_test.cpp: shape-inference tests for MIGraphX operators.
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>

#include <algorithm>
#include <iostream>
#include <sstream>
#include <vector>

#include "test.hpp"

Paul's avatar
Paul committed
12
template <class... Ts>
Paul's avatar
Paul committed
13
void expect_shape(const migraphx::shape& expected, const migraphx::operation& op, Ts... xs)
14
{
Paul's avatar
Paul committed
15
    migraphx::program p;
16
    auto* mm = p.get_main_module();
Paul's avatar
Paul committed
17
18
    std::vector<migraphx::shape> shapes{xs...};
    std::vector<migraphx::instruction_ref> args(shapes.size());
Paul's avatar
Paul committed
19
    std::transform(
20
21
        shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return mm->add_outline(s); });
    mm->add_instruction(op, args);
22
    if(p.get_output_shapes().back() != expected)
Paul's avatar
Paul committed
23
    {
24
        std::cout << "FAILED: Incorrect shape for " << op << ": ";
25
        std::cout << expected << " != " << p.get_output_shapes().back() << std::endl;
Paul's avatar
Paul committed
26
        for(auto&& s : shapes)
27
28
29
30
            std::cout << "    " << s << std::endl;
    }
}

Paul's avatar
Paul committed
31
template <class... Ts>
Paul's avatar
Paul committed
32
void throws_shape(const migraphx::operation& op, Ts... xs)
33
{
Paul's avatar
Paul committed
34
    migraphx::program p;
35
    auto* mm = p.get_main_module();
Paul's avatar
Paul committed
36
37
    std::vector<migraphx::shape> shapes{xs...};
    std::vector<migraphx::instruction_ref> args(shapes.size());
Paul's avatar
Paul committed
38
    std::transform(
39
40
        shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return mm->add_outline(s); });
    bool thrown = test::throws([&] { mm->add_instruction(op, args); });
Paul's avatar
Paul committed
41
42
    if(not thrown)
    {
43
        std::cout << "FAILED: No error found for " << op.name() << ": ";
Paul's avatar
Paul committed
44
        for(auto&& s : shapes)
45
46
47
48
            std::cout << "    " << s << std::endl;
    }
}

Paul's avatar
Paul committed
49
50
51
52
template <class...>
struct always_false : std::false_type
{
};
53

Paul's avatar
Paul committed
54
template <class... Ts>
Paul's avatar
Paul committed
55
void throws_shape(const migraphx::shape&, Ts...)
56
{
Paul's avatar
Paul committed
57
58
    static_assert(always_false<Ts...>{},
                  "An expected shape should not be passed to throws_shape function");
59
60
}

Paul's avatar
Paul committed
61
TEST_CASE(batch_norm_inference_shape)
62
63
{
    const size_t channels = 3;
Paul's avatar
Paul committed
64
65
    migraphx::shape s{migraphx::shape::float_type, {4, channels, 3, 3}};
    migraphx::shape vars{migraphx::shape::float_type, {channels}};
66
67
68
    expect_shape(s, migraphx::make_op("batch_norm_inference"), s, vars, vars, vars, vars);
    throws_shape(migraphx::make_op("batch_norm_inference"), s);
    throws_shape(migraphx::make_op("batch_norm_inference"), s, vars, vars, vars, vars, vars);
69
70
}

Paul's avatar
Paul committed
71
TEST_CASE(convolution_shape)
72
{
Paul's avatar
Paul committed
73
74
75
    migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
    migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
    migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
76
77
78
79
80
    expect_shape(output, migraphx::make_op("convolution"), input, weights);
    throws_shape(migraphx::make_op("convolution"), input);
    throws_shape(
        migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
        input);
Paul's avatar
Paul committed
81
82
83

    migraphx::shape input2{migraphx::shape::float_type, {3, 3}};
    migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
84
85
    throws_shape(migraphx::make_op("convolution"), input2, weights2);
    throws_shape(migraphx::make_op("convolution"), input2, weights);
86
87
88
89

    migraphx::shape output_1d{migraphx::shape::float_type, {4, 4, 1}};
    migraphx::shape input_1d{migraphx::shape::float_type, {4, 3, 3}};
    migraphx::shape weights_1d{migraphx::shape::float_type, {4, 3, 3}};
90
91
92
93
94
    expect_shape(
        output_1d,
        migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
        input_1d,
        weights_1d);
95
96
97
98

    migraphx::shape output_3d{migraphx::shape::float_type, {4, 4, 1, 1, 1}};
    migraphx::shape input_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
    migraphx::shape weights_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
99
100
101
102
103
104
105
106
    expect_shape(
        output_3d,
        migraphx::make_op("convolution",
                          {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}),
        input_3d,
        weights_3d);

    throws_shape(migraphx::make_op("convolution"), input_3d, weights_3d);
107
108
}

kahmed10's avatar
kahmed10 committed
109
110
111
112
113
TEST_CASE(deconvolution_shape)
{
    migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}};
    migraphx::shape output{migraphx::shape::float_type, {4, 3, 3, 3}};
    migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
114
115
116
117
118
    expect_shape(output, migraphx::make_op("deconvolution"), input, weights);
    throws_shape(migraphx::make_op("deconvolution"), input);
    throws_shape(
        migraphx::make_op("deconvolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
        input);
kahmed10's avatar
kahmed10 committed
119
120
121
122

    migraphx::shape input_1d{migraphx::shape::float_type, {4, 4, 1}};
    migraphx::shape output_1d{migraphx::shape::float_type, {4, 3, 3}};
    migraphx::shape weights_1d{migraphx::shape::float_type, {4, 3, 3}};
123
124
125
126
127
    expect_shape(
        output_1d,
        migraphx::make_op("deconvolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
        input_1d,
        weights_1d);
kahmed10's avatar
kahmed10 committed
128
129
130
131

    migraphx::shape input_3d{migraphx::shape::float_type, {4, 4, 1, 1, 1}};
    migraphx::shape output_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
    migraphx::shape weights_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
132
133
134
135
136
137
    expect_shape(
        output_3d,
        migraphx::make_op("deconvolution",
                          {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}),
        input_3d,
        weights_3d);
kahmed10's avatar
kahmed10 committed
138
139
}

140
141
TEST_CASE(quant_convolution_shape)
{
142
    migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
143
144
    migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}};
    migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
145
146
147
148
149
150
151
152
153
154
    expect_shape(output, migraphx::make_op("quant_convolution"), input, weights);
    throws_shape(migraphx::make_op("quant_convolution"), input);
    throws_shape(migraphx::make_op("quant_convolution",
                                   {{"padding", {0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
                 input,
                 weights);
    throws_shape(migraphx::make_op("quant_convolution",
                                   {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
                 input,
                 weights);
155

156
    migraphx::shape input2{migraphx::shape::int32_type, {3, 3}};
157
    migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
158
159
    throws_shape(migraphx::make_op("quant_convolution"), input2, weights2);
    throws_shape(migraphx::make_op("quant_convolution"), input2, weights);
160

161
    migraphx::shape input3{migraphx::shape::int32_type, {4, 3, 3, 3}};
162
    migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}};
163
164
165
    throws_shape(migraphx::make_op("quant_convolution"), input3, weights);
    throws_shape(migraphx::make_op("quant_convolution"), input, weight3);
    throws_shape(migraphx::make_op("quant_convolution"), input3, weight3);
166
167
}

kahmed10's avatar
kahmed10 committed
168
169
170
171
TEST_CASE(pooling_shape)
{
    migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
    migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
172
173
174
175
176
177
178
179
180
181
    throws_shape(
        migraphx::make_op("pooling",
                          {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1}}}),
        input);
    expect_shape(
        output,
        migraphx::make_op(
            "pooling",
            {{"mode", "max"}, {"padding", {0, 0}}, {"stride", {3, 3}}, {"lengths", {1, 1}}}),
        input);
Shucai Xiao's avatar
Shucai Xiao committed
182
183

    migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}};
184
185
186
187
188
189
190
191
    expect_shape(output1,
                 migraphx::make_op("pooling",
                                   {{"mode", "max"},
                                    {"padding", {0, 0}},
                                    {"stride", {3, 3}},
                                    {"lengths", {1, 1}},
                                    {"ceil_mode", true}}),
                 input);
kahmed10's avatar
kahmed10 committed
192
193
}

194
195
196
197
TEST_CASE(inconsistent_attr_shape)
{
    migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
    migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
198
199
200
201
202
203
204
205
206
207
208
209
    throws_shape(migraphx::make_op("convolution",
                                   {{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
                 input,
                 weights);
    throws_shape(migraphx::make_op("deconvolution",
                                   {{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
                 input,
                 weights);
    throws_shape(
        migraphx::make_op(
            "pooling", {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1, 1}}}),
        input);
210
211
}

Paul's avatar
Paul committed
212
TEST_CASE(transpose_shape)
213
{
Paul's avatar
Paul committed
214
215
    migraphx::shape input{migraphx::shape::float_type, {2, 2}};
    migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
216
217
218
219
    expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input);
    expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
    expect_shape(output, migraphx::make_op("transpose"), input);
    throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input);
220
221
}

Paul's avatar
Paul committed
222
TEST_CASE(contiguous_shape)
223
{
Paul's avatar
Paul committed
224
225
    migraphx::shape output{migraphx::shape::float_type, {2, 2}};
    migraphx::shape input{migraphx::shape::float_type, {2, 2}, {1, 2}};
226
227
    expect_shape(output, migraphx::make_op("contiguous"), input);
    throws_shape(migraphx::make_op("contiguous"), input, input);
Paul's avatar
Paul committed
228

Paul's avatar
Paul committed
229
    migraphx::shape single{migraphx::shape::float_type, {2}};
230
    expect_shape(single, migraphx::make_op("contiguous"), single);
231
232
}

233
234
235
236
TEST_CASE(contiguous_shape_scalar)
{
    migraphx::shape output{migraphx::shape::float_type};
    migraphx::shape input{migraphx::shape::float_type};
237
    expect_shape(output, migraphx::make_op("contiguous"), input);
238
239
}

Paul's avatar
Paul committed
240
TEST_CASE(reshape_shape)
241
{
Paul's avatar
Paul committed
242
    migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
Paul's avatar
Paul committed
243
244
    for(auto&& new_shape :
        std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
245
246
247
    {
        std::vector<std::size_t> lens(new_shape.size());
        std::copy(new_shape.begin(), new_shape.end(), lens.begin());
Paul's avatar
Paul committed
248
        migraphx::shape output{migraphx::shape::float_type, lens};
249
        expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
250
251
    }

Shucai Xiao's avatar
Shucai Xiao committed
252
253
    for(auto&& new_shape :
        std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
254
    {
255
        throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
256
    }
Shucai Xiao's avatar
Shucai Xiao committed
257

Shucai Xiao's avatar
Shucai Xiao committed
258
    std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
Shucai Xiao's avatar
Shucai Xiao committed
259
260
261
262
263
264
265
266
        {{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
        {{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
        {{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
        {{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
        {{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
        {{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
Shucai Xiao's avatar
Shucai Xiao committed
267
        {{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};
Shucai Xiao's avatar
Shucai Xiao committed
268

Shucai Xiao's avatar
Shucai Xiao committed
269
    for(auto& it : minus1_tests)
Shucai Xiao's avatar
Shucai Xiao committed
270
    {
271
        expect_shape(it.second, migraphx::make_op("reshape", {{"dims", it.first}}), input);
Shucai Xiao's avatar
Shucai Xiao committed
272
    }
273
274
}

Paul's avatar
Paul committed
275
TEST_CASE(flatten_shape)
276
{
Paul's avatar
Paul committed
277
278
    migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
    expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2 * 4 * 6 * 8}},
279
                 migraphx::make_op("flatten", {{"axis", 0}}),
Scott Thornton's avatar
Scott Thornton committed
280
                 input);
281
    expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2 * 4 * 6 * 8}},
282
                 migraphx::make_op("flatten", {{"axis", -4}}),
283
                 input);
Paul's avatar
Paul committed
284
    expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 4 * 6 * 8}},
285
                 migraphx::make_op("flatten", {{"axis", 1}}),
Paul's avatar
Paul committed
286
                 input);
287
    expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 4 * 6 * 8}},
288
                 migraphx::make_op("flatten", {{"axis", -3}}),
289
                 input);
Paul's avatar
Paul committed
290
    expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4, 6 * 8}},
291
                 migraphx::make_op("flatten", {{"axis", 2}}),
Paul's avatar
Paul committed
292
293
                 input);
    expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4 * 6, 8}},
294
                 migraphx::make_op("flatten", {{"axis", 3}}),
Paul's avatar
Paul committed
295
                 input);
Paul's avatar
Paul committed
296
    expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4 * 6 * 8, 1}},
297
                 migraphx::make_op("flatten", {{"axis", 4}}),
Scott Thornton's avatar
Scott Thornton committed
298
                 input);
299
300
    throws_shape(migraphx::make_op("flatten", {{"axis", 5}}), input);
    throws_shape(migraphx::make_op("flatten", {{"axis", -5}}), input);
301
302
}

Paul's avatar
Paul committed
303
TEST_CASE(slice_shape)
Scott Thornton's avatar
Scott Thornton committed
304
{
Paul's avatar
Paul committed
305
306
    migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
    expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
307
                 migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}),
Scott Thornton's avatar
Scott Thornton committed
308
                 input);
Paul's avatar
Paul committed
309
    expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
310
311
                 migraphx::make_op(
                     "slice", {{"axes", {0, 1, 2}}, {"starts", {0, 0, 1}}, {"ends", {2, 2, 3}}}),
Scott Thornton's avatar
Scott Thornton committed
312
                 input);
Paul's avatar
Paul committed
313
    expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
314
                 migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {10}}}),
Scott Thornton's avatar
Scott Thornton committed
315
316
                 input);
}
Scott Thornton's avatar
Scott Thornton committed
317

wsttiger's avatar
wsttiger committed
318
TEST_CASE(multibroadcast)
Scott Thornton's avatar
Scott Thornton committed
319
320
{
    {
321
        std::vector<std::size_t> lens{4, 2, 5, 3};
Paul's avatar
Paul committed
322
323
        migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
324
                     migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
325
                     input);
Scott Thornton's avatar
Scott Thornton committed
326
327
    }
    {
328
        std::vector<std::size_t> lens{4, 2, 5, 3};
Paul's avatar
Paul committed
329
330
        migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
331
                     migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
332
                     input);
Scott Thornton's avatar
Scott Thornton committed
333
334
    }
    {
335
        std::vector<std::size_t> lens{4, 2, 5, 3};
Paul's avatar
Paul committed
336
337
        migraphx::shape input{migraphx::shape::float_type, {5, 1}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
338
                     migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
339
                     input);
Scott Thornton's avatar
Scott Thornton committed
340
341
    }
    {
342
        std::vector<std::size_t> lens{4, 2, 5, 3};
Paul's avatar
Paul committed
343
344
        migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
345
                     migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
346
347
348
349
                     input);
    }
    {
        std::vector<std::size_t> lens{4, 2, 5, 3};
Paul's avatar
Paul committed
350
351
        migraphx::shape input{migraphx::shape::float_type, {3}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
352
                     migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
353
354
                     input);
    }
355
356
    {
        std::vector<std::size_t> lens{4, 4, 1, 3};
Paul's avatar
Paul committed
357
358
        migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
359
                     migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
360
361
                     input);
    }
362
363
    {
        std::vector<std::size_t> lens{4, 1, 1, 3};
Paul's avatar
Paul committed
364
365
        migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
366
                     migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
367
368
369
370
                     input);
    }
    {
        std::vector<std::size_t> lens{4, 1, 3};
Paul's avatar
Paul committed
371
        migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
372
        throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
Scott Thornton's avatar
Scott Thornton committed
373
374
    }
    {
375
        std::vector<std::size_t> lens{4, 1, 3};
Paul's avatar
Paul committed
376
        migraphx::shape input{migraphx::shape::float_type, {}};
377
        throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
Scott Thornton's avatar
Scott Thornton committed
378
    }
Shucai Xiao's avatar
Shucai Xiao committed
379
380
381
    {
        std::vector<std::size_t> lens{2, 3, 4, 5};
        migraphx::shape input{migraphx::shape::float_type, {3, 4}};
382
        throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
Shucai Xiao's avatar
Shucai Xiao committed
383
384
385
386
    }
    {
        std::vector<std::size_t> lens{2, 3, 4, 5};
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
387
        throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
Shucai Xiao's avatar
Shucai Xiao committed
388
    }
Scott Thornton's avatar
Scott Thornton committed
389
390
}

391
392
393
394
TEST_CASE(broadcast)
{
    {
        std::vector<std::size_t> lens{1, 1};
395
        migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
396
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
397
                     migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}),
398
399
                     input);
    }
kahmed10's avatar
kahmed10 committed
400

401
402
    {
        std::vector<std::size_t> lens{1, 1};
403
404
405
406
407
408
409
410
        migraphx::shape input{migraphx::shape::float_type, {2}};
        throws_shape(migraphx::op::broadcast{1, lens}, input);
    }

    {
        std::vector<std::size_t> lens{2, 2};
        migraphx::shape input{migraphx::shape::float_type, {1, 2}};
        throws_shape(migraphx::op::broadcast{1, lens}, input);
411
412
413
414
415
416
    }

    {
        std::vector<std::size_t> lens{3, 2, 4, 3};
        migraphx::shape input{migraphx::shape::float_type, {4, 3}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
417
                     migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}),
418
419
420
421
422
423
                     input);
    }

    {
        std::vector<std::size_t> lens{3, 2, 4, 3};
        migraphx::shape input{migraphx::shape::float_type, {4, 4}};
424
        throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), input);
425
426
427
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
428
TEST_CASE(gather)
429
430
431
432
{
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
433
        int axis = 1;
434
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}},
435
                     migraphx::make_op("gather", {{"axis", axis}}),
Shucai Xiao's avatar
Shucai Xiao committed
436
437
                     input,
                     indices);
438
439
    }

440
441
442
443
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
        int axis = -4;
444
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}},
445
                     migraphx::make_op("gather", {{"axis", axis}}),
446
447
448
449
                     input,
                     indices);
    }

450
451
452
453
454
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        migraphx::shape indices{migraphx::shape::int32_type, {1}};
        int axis = -4;
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
455
                     migraphx::make_op("gather", {{"axis", axis}}),
456
457
458
459
460
461
                     input,
                     indices);
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
462
        migraphx::shape indices{migraphx::shape::int32_type};
463
464
        int axis = -4;
        expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
465
                     migraphx::make_op("gather", {{"axis", axis}}),
466
467
468
469
470
471
                     input,
                     indices);
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
472
        migraphx::shape indices{migraphx::shape::int32_type};
473
474
        int axis = 3;
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4}},
475
                     migraphx::make_op("gather", {{"axis", axis}}),
476
477
478
479
480
481
                     input,
                     indices);
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {3}};
482
        migraphx::shape indices{migraphx::shape::int32_type};
483
        int axis = 0;
484
        expect_shape(migraphx::shape{migraphx::shape::float_type},
485
                     migraphx::make_op("gather", {{"axis", axis}}),
486
487
488
489
490
491
492
493
494
                     input,
                     indices);
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {3}};
        migraphx::shape indices{migraphx::shape::int32_type, {1}};
        int axis = 0;
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1}},
495
                     migraphx::make_op("gather", {{"axis", axis}}),
496
497
498
499
                     input,
                     indices);
    }

500
501
502
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
503
        int axis = 4;
504
        throws_shape(migraphx::make_op("gather", {{"axis", axis}}), input, indices);
505
    }
506
507
508
509
510

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
        int axis = -5;
511
        throws_shape(migraphx::make_op("gather", {{"axis", axis}}), input, indices);
512
    }
513
514
}

Khalique's avatar
Khalique committed
515
template <class T>
Khalique's avatar
Khalique committed
516
void test_softmax_variations()
517
518
519
{
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
520
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
521
522
523
524
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
525
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
526
527
528
529
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
530
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
531
532
533
534
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
535
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
536
537
538
539
540
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        int axis = 4;
Khalique's avatar
Khalique committed
541
        throws_shape(T{axis}, input);
542
543
544
    }
}

Khalique's avatar
Khalique committed
545
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
546

Khalique's avatar
Khalique committed
547
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
548

Shucai Xiao's avatar
Shucai Xiao committed
549
TEST_CASE(test_argmax)
550
551
552
{
    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
553
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
554
                     migraphx::make_op("argmax", {{"axis", 0}}),
Shucai Xiao's avatar
Shucai Xiao committed
555
                     input);
Shucai Xiao's avatar
Shucai Xiao committed
556
557
558
559
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
560
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
561
                     migraphx::make_op("argmax", {{"axis", 1}}),
Shucai Xiao's avatar
Shucai Xiao committed
562
                     input);
Shucai Xiao's avatar
Shucai Xiao committed
563
564
565
566
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
567
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
568
                     migraphx::make_op("argmax", {{"axis", 2}}),
Shucai Xiao's avatar
Shucai Xiao committed
569
                     input);
Shucai Xiao's avatar
Shucai Xiao committed
570
571
572
573
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
574
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
575
                     migraphx::make_op("argmax", {{"axis", 3}}),
Shucai Xiao's avatar
Shucai Xiao committed
576
                     input);
Shucai Xiao's avatar
Shucai Xiao committed
577
578
579
580
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
581
        throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input);
582
    }
Shucai Xiao's avatar
Shucai Xiao committed
583
}
584

Shucai Xiao's avatar
Shucai Xiao committed
585
TEST_CASE(test_argmin)
586
587
588
{
    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
589
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
590
                     migraphx::make_op("argmin", {{"axis", 0}}),
Shucai Xiao's avatar
Shucai Xiao committed
591
                     input);
592
593
594
595
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
596
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
597
                     migraphx::make_op("argmin", {{"axis", 1}}),
Shucai Xiao's avatar
Shucai Xiao committed
598
                     input);
599
600
601
602
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
603
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
604
                     migraphx::make_op("argmin", {{"axis", 2}}),
Shucai Xiao's avatar
Shucai Xiao committed
605
                     input);
606
607
608
609
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
610
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
611
                     migraphx::make_op("argmin", {{"axis", 3}}),
Shucai Xiao's avatar
Shucai Xiao committed
612
                     input);
Shucai Xiao's avatar
Shucai Xiao committed
613
614
615
616
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
617
        throws_shape(migraphx::make_op("argmin", {{"axis", 4}}), input);
618
    }
Shucai Xiao's avatar
Shucai Xiao committed
619
}
620

621
622
623
624
TEST_CASE(test_scalar)
{
    migraphx::shape s1{migraphx::shape::float_type, {1}, {1}};
    migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 5}, {0, 0, 0, 0}};
625
    expect_shape(s2, migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), s1);
626
627
628
629
630
}

TEST_CASE(test_scalar_nelemnts)
{
    // An input with more than one element is rejected by the scalar op.
    migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
    throws_shape(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), input);
}

634
635
TEST_CASE(test_squeeze)
{
636
637
    migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
    migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
638
    expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1);
639
}
640

641
642
643
644
TEST_CASE(test_squeeze_negative_axis)
{
    migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
    migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
645
    expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1);
646
647
648
649
650
}

TEST_CASE(test_squeeze_wrong_axis)
{
    // Axis 0 has length 4, so it cannot be squeezed.
    migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
    throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}

TEST_CASE(test_squeeze_all)
{
    // Squeezing the only dim of a {1} shape yields a type-only shape.
    migraphx::shape s1{migraphx::shape::float_type, {1}};
    migraphx::shape s2{migraphx::shape::float_type};
    expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}

TEST_CASE(test_unsqueeze_scalar)
{
    // A {1} shape with stride 0 unsqueezed at axis 0 becomes {1} with stride 1.
    migraphx::shape s1{migraphx::shape::float_type, {1}, {0}};
    migraphx::shape s2{migraphx::shape::float_type, {1}, {1}};
    expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1);
}

TEST_CASE(test_unsqueeze_scalar_tensor1)
{
    // An all-zero-stride multi-dim input is rejected.
    migraphx::shape s{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}};
    throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
}

TEST_CASE(test_unsqueeze_scalar_tensor2)
{
    // Also rejected when every dim has length 1 but all strides are zero.
    migraphx::shape s{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}};
    throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
}

TEST_CASE(test_unsqueeze)
{
    // Axis 2 refers to the output: a new length-1 dim is inserted at position 2.
    migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
    migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
    expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}

TEST_CASE(test_unsqueeze_negative_axis)
{
    // Axis -2 gives the same result as axis 2 in test_unsqueeze.
    migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
    migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
    expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
}

Shucai Xiao's avatar
Shucai Xiao committed
694
695
template <class T>
void test_reduce_ops()
Shucai Xiao's avatar
Shucai Xiao committed
696
{
697
698
699
700
701
702
703
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
704
705
        expect_shape(
            migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
706
    }
Shucai Xiao's avatar
Shucai Xiao committed
707
708
709
710
711
712
713
714
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
    }
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
    }
715
716
717
718
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
    }
719
720
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
721
        throws_shape(T{{4}}, input);
722
723
724
    }
}

// Run the shared reduction shape checks against each reduce op.
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }

// Shape inference for the two-input "dot" op.
// Operands of rank < 2, mismatched inner dimensions, and disagreeing leading (batch)
// dimensions must throw; matching operands {..., M, K} x {..., K, N} give {..., M, N}.
TEST_CASE(matmul)
{
    const auto dot = migraphx::make_op("dot");
    auto fshape    = [](std::vector<std::size_t> lens) {
        return migraphx::shape{migraphx::shape::float_type, lens};
    };

    // rank-1 operands are rejected
    throws_shape(dot, fshape({5}), fshape({5}));
    throws_shape(dot, fshape({5}), fshape({5, 2}));
    throws_shape(dot, fshape({1, 5}), fshape({5}));

    expect_shape(fshape({1, 4}), dot, fshape({1, 5}), fshape({5, 4}));
    // inner dimensions 5 vs 4 do not match
    throws_shape(dot, fshape({1, 5}), fshape({4, 4}));
    // operand ranks differ
    throws_shape(dot, fshape({1, 5}), fshape({6, 5, 4}));

    expect_shape(fshape({6, 1, 4}), dot, fshape({6, 1, 5}), fshape({6, 5, 4}));
    expect_shape(fshape({1, 6, 1, 4}), dot, fshape({1, 6, 1, 5}), fshape({1, 6, 5, 4}));
    expect_shape(fshape({4, 8}), dot, fshape({4, 5}), fshape({5, 8}));
    expect_shape(fshape({1, 1}), dot, fshape({1, 1}), fshape({1, 1}));
    expect_shape(fshape({1, 4, 7}), dot, fshape({1, 4, 5}), fshape({1, 5, 7}));

    // leading dimensions are not broadcast
    throws_shape(dot, fshape({4, 5}), fshape({1, 1, 5, 7}));
    throws_shape(dot, fshape({1, 1, 4, 5}), fshape({1, 2, 5, 7}));
}

// Shape inference for the three-input "dot" op (A x B plus C).
// C is only accepted when it has exactly the {..., M, N} result shape; any other C shape
// (including broadcastable ones and a scalar) and mismatched A/B dimensions must throw.
TEST_CASE(gemm)
{
    const auto dot = migraphx::make_op("dot");
    auto fshape    = [](std::vector<std::size_t> lens) {
        return migraphx::shape{migraphx::shape::float_type, lens};
    };

    // C does not match the {4, 8} result
    throws_shape(dot, fshape({4, 5}), fshape({5, 8}), fshape({1}));
    throws_shape(dot, fshape({4, 5}), fshape({5, 8}), fshape({1, 1}));
    throws_shape(dot, fshape({4, 5}), fshape({5, 8}), fshape({8}));
    throws_shape(dot, fshape({4, 5}), fshape({5, 8}), fshape({4, 1}));
    // inner dimensions 6 vs 5 do not match
    throws_shape(dot, fshape({4, 6}), fshape({5, 8}), fshape({4, 8}));
    throws_shape(dot, fshape({4, 5}), fshape({5, 8}), fshape({4}));

    expect_shape(fshape({4, 8}), dot, fshape({4, 5}), fshape({5, 8}), fshape({4, 8}));
    expect_shape(
        fshape({1, 4, 8}), dot, fshape({1, 4, 5}), fshape({1, 5, 8}), fshape({1, 4, 8}));

    // inner dimensions 6 vs 5 do not match
    throws_shape(dot, fshape({1, 4, 6}), fshape({1, 5, 8}), fshape({1, 4, 8}));
    // C is missing the leading dimension
    throws_shape(dot, fshape({1, 4, 5}), fshape({1, 5, 8}), fshape({4, 8}));
    // C is a shape with no dimensions
    throws_shape(
        dot, fshape({1, 4, 5}), fshape({1, 5, 8}), migraphx::shape{migraphx::shape::float_type});
}

917
918
919
920
921
922
923
// quant_dot
TEST_CASE(quant_dot_2args)
{
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
924
                     migraphx::make_op("quant_dot"),
925
926
927
928
929
930
931
932
                     s_m1,
                     s_m2);
    }

    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
933
                     migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}),
934
935
936
937
938
939
940
                     s_m1,
                     s_m2);
    }

    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
941
        throws_shape(migraphx::make_op("quant_dot"), s_m1, s_m2);
942
943
944
945
946
947
948
949
950
951
    }
}

// Shape inference for the three-input "quant_dot" op: an int32 C of the result shape is
// accepted, while an int8 C must be rejected.
TEST_CASE(quant_dot_3args)
{
    auto i8 = [](std::vector<std::size_t> lens) {
        return migraphx::shape{migraphx::shape::int8_type, lens};
    };
    auto i32 = [](std::vector<std::size_t> lens) {
        return migraphx::shape{migraphx::shape::int32_type, lens};
    };

    expect_shape(
        i32({2, 8}), migraphx::make_op("quant_dot"), i8({2, 4}), i8({4, 8}), i32({2, 8}));
    // same operands as above, but C is int8 (and beta is 2)
    throws_shape(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}),
                 i8({2, 4}),
                 i8({4, 8}),
                 i8({2, 8}));
}

966
967
968
969
970
971
972
973
TEST_CASE(rnn)
{
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
974
        float clip              = 0.0f;
975
976
977
978
979
980
981

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
982
983
984
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
985
986
987
988
989
990
991
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
992
993
994
995
996
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
997
998
999
1000
1001
1002
1003
1004
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
1005
        float clip              = 0.0f;
1006
1007
1008
1009
1010
1011
1012

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1013
1014
1015
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
1016
1017
1018
1019
1020
1021
1022
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
1023
1024
1025
1026
1027
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1028
1029
1030
1031
1032
1033
1034
1035
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
Shucai Xiao's avatar
Shucai Xiao committed
1036
        float clip              = 0.0f;
1037
1038
1039
1040
1041
1042
1043

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1059
1060
1061
1062
1063
1064
1065
1066
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
1067
        float clip              = 0.0f;
1068
1069
1070
1071
1072
1073
1074

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
        throws_shape(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size + 1},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1088
1089
1090
1091
1092
1093
1094
1095
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
1096
        float clip              = 0.0f;
1097
1098
1099
1100
1101
1102
1103

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
        throws_shape(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1117
1118
1119
1120
1121
1122
1123
1124
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
Shucai Xiao's avatar
Shucai Xiao committed
1125
        float clip              = 0.0f;
1126
1127
1128
1129
1130
1131
1132

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1133
        throws_shape(
1134
1135
1136
1137
1138
1139
1140
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
1141
1142
1143
1144
1145
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1146
1147
1148
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
TEST_CASE(gru)
{
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1160
1161
1162
1163
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1164
1165
1166
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1167
1168
1169
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
1170
1171
1172
1173
1174
1175
1176
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
1177
1178
1179
1180
1181
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1193
1194
1195
1196
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1197
1198
1199
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1200
1201
1202
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
1203
1204
1205
1206
1207
1208
1209
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
1210
1211
1212
1213
1214
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1226
1227
1228
1229
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1230
1231
1232
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1259
1260
1261
1262
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1263
1264
1265
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
        throws_shape(
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size + 1},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1290
1291
1292
1293
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1294
1295
1296
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
        throws_shape(
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1321
1322
1323
1324
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1325
1326
1327
1328
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        throws_shape(
1329
1330
1331
1332
1333
1334
1335
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
1336
1337
1338
1339
1340
1341
1342
1343
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }
}

1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
TEST_CASE(lstm)
{
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};

        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
1363
1364
1365
1366
1367
1368
1369
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
            in_shape,
            w_shape,
            r_shape);
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
1394
1395
1396
1397
1398
1399
1400
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
        throws_shape(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size + 1},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
        throws_shape(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        throws_shape(
1520
1521
1522
1523
1524
1525
1526
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
1527
1528
1529
1530
1531
1532
1533
1534
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }
}

Paul's avatar
Paul committed
1535
int main(int argc, const char* argv[]) { test::run(argc, argv); }