op_shape_test.cpp 59.7 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
5
#include <sstream>
6
7
8
9
#include <migraphx/make_op.hpp>

#include <migraphx/serialize.hpp>

10
11
#include "test.hpp"

Paul's avatar
Paul committed
12
template <class... Ts>
Paul's avatar
Paul committed
13
void expect_shape(const migraphx::shape& expected, const migraphx::operation& op, Ts... xs)
14
{
Paul's avatar
Paul committed
15
    migraphx::program p;
16
    auto* mm = p.get_main_module();
Paul's avatar
Paul committed
17
18
    std::vector<migraphx::shape> shapes{xs...};
    std::vector<migraphx::instruction_ref> args(shapes.size());
Paul's avatar
Paul committed
19
    std::transform(
20
21
        shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return mm->add_outline(s); });
    mm->add_instruction(op, args);
22
    if(p.get_output_shapes().back() != expected)
Paul's avatar
Paul committed
23
    {
24
        std::cout << "FAILED: Incorrect shape for " << op.name() << ": ";
25
        std::cout << expected << " != " << p.get_output_shapes().back() << std::endl;
Paul's avatar
Paul committed
26
        for(auto&& s : shapes)
27
28
29
30
            std::cout << "    " << s << std::endl;
    }
}

Paul's avatar
Paul committed
31
template <class... Ts>
Paul's avatar
Paul committed
32
void throws_shape(const migraphx::operation& op, Ts... xs)
33
{
Paul's avatar
Paul committed
34
    migraphx::program p;
35
    auto* mm = p.get_main_module();
Paul's avatar
Paul committed
36
37
    std::vector<migraphx::shape> shapes{xs...};
    std::vector<migraphx::instruction_ref> args(shapes.size());
Paul's avatar
Paul committed
38
    std::transform(
39
40
        shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return mm->add_outline(s); });
    bool thrown = test::throws([&] { mm->add_instruction(op, args); });
Paul's avatar
Paul committed
41
42
    if(not thrown)
    {
43
        std::cout << "FAILED: No error found for " << op.name() << ": ";
Paul's avatar
Paul committed
44
        for(auto&& s : shapes)
45
46
47
48
            std::cout << "    " << s << std::endl;
    }
}

Paul's avatar
Paul committed
49
50
51
52
/// Dependent false constant: `always_false<Ts...>::value` is false for every
/// pack `Ts...`. Because the value depends on template parameters, a
/// static_assert on it fires only when the enclosing template is instantiated.
template <class...>
struct always_false : std::integral_constant<bool, false>
{
};
53

Paul's avatar
Paul committed
54
template <class... Ts>
Paul's avatar
Paul committed
55
void throws_shape(const migraphx::shape&, Ts...)
56
{
Paul's avatar
Paul committed
57
58
    static_assert(always_false<Ts...>{},
                  "An expected shape should not be passed to throws_shape function");
59
60
}

Paul's avatar
Paul committed
61
TEST_CASE(batch_norm_inference_shape)
{
    // Dimension 1 of the data has `channels` entries; each of the extra
    // operands is a 1-D tensor of that length.
    const size_t channels = 3;
    const migraphx::shape data{migraphx::shape::float_type, {4, channels, 3, 3}};
    const migraphx::shape per_channel{migraphx::shape::float_type, {channels}};
    const auto bn = migraphx::make_op("batch_norm_inference");

    // Data plus four per-channel operands: the output keeps the data shape.
    expect_shape(data, bn, data, per_channel, per_channel, per_channel, per_channel);
    // Too few operands.
    throws_shape(bn, data);
    // Too many operands.
    throws_shape(bn, data, per_channel, per_channel, per_channel, per_channel, per_channel);
}

Paul's avatar
Paul committed
71
TEST_CASE(convolution_shape)
{
    const auto conv    = migraphx::make_op("convolution");
    const auto conv_1d =
        migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}});
    const auto conv_3d = migraphx::make_op(
        "convolution", {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}});

    // Rank-4 operands with default attributes.
    const migraphx::shape in_2d{migraphx::shape::float_type, {4, 3, 3, 3}};
    const migraphx::shape w_2d{migraphx::shape::float_type, {4, 3, 3, 3}};
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 1}}, conv, in_2d, w_2d);
    throws_shape(conv, in_2d);    // single operand
    throws_shape(conv_1d, in_2d); // single operand

    // Rank-2 data is rejected, with rank-2 or rank-4 weights.
    const migraphx::shape in_rank2{migraphx::shape::float_type, {3, 3}};
    const migraphx::shape w_rank2{migraphx::shape::float_type, {3, 3}};
    throws_shape(conv, in_rank2, w_rank2);
    throws_shape(conv, in_rank2, w_2d);

    // Rank-3 operands with single-element attributes.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 1}},
                 conv_1d,
                 migraphx::shape{migraphx::shape::float_type, {4, 3, 3}},
                 migraphx::shape{migraphx::shape::float_type, {4, 3, 3}});

    // Rank-5 operands: accepted with three-element attributes, rejected with defaults.
    const migraphx::shape in_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
    const migraphx::shape w_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
    expect_shape(
        migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 1, 1}}, conv_3d, in_3d, w_3d);
    throws_shape(conv, in_3d, w_3d);
}

kahmed10's avatar
kahmed10 committed
109
110
111
112
113
TEST_CASE(deconvolution_shape)
{
    const auto deconv    = migraphx::make_op("deconvolution");
    const auto deconv_1d = migraphx::make_op(
        "deconvolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}});
    const auto deconv_3d = migraphx::make_op(
        "deconvolution", {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}});

    // Rank-4 operands with default attributes: {4, 4, 1, 1} grows to {4, 3, 3, 3}.
    const migraphx::shape in_2d{migraphx::shape::float_type, {4, 4, 1, 1}};
    const migraphx::shape w_2d{migraphx::shape::float_type, {4, 3, 3, 3}};
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}, deconv, in_2d, w_2d);
    throws_shape(deconv, in_2d);    // single operand
    throws_shape(deconv_1d, in_2d); // single operand

    // Rank-3 operands with single-element attributes.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 3, 3}},
                 deconv_1d,
                 migraphx::shape{migraphx::shape::float_type, {4, 4, 1}},
                 migraphx::shape{migraphx::shape::float_type, {4, 3, 3}});

    // Rank-5 operands with three-element attributes.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3, 3}},
                 deconv_3d,
                 migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 1, 1}},
                 migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3, 3}});
}

140
141
TEST_CASE(quant_convolution_shape)
{
    const auto qconv = migraphx::make_op("quant_convolution");

    // int8 data and int8 weights produce an int32 result.
    const migraphx::shape in_i8{migraphx::shape::int8_type, {4, 3, 3, 3}};
    const migraphx::shape w_i8{migraphx::shape::int8_type, {4, 3, 3, 3}};
    expect_shape(migraphx::shape{migraphx::shape::int32_type, {4, 4, 1, 1}}, qconv, in_i8, w_i8);
    throws_shape(qconv, in_i8); // single operand

    // Attribute vectors of unequal sizes, then all single-element attributes.
    throws_shape(migraphx::make_op("quant_convolution",
                                   {{"padding", {0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
                 in_i8,
                 w_i8);
    throws_shape(migraphx::make_op("quant_convolution",
                                   {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
                 in_i8,
                 w_i8);

    // Rank-2 data is rejected.
    const migraphx::shape in_rank2{migraphx::shape::int32_type, {3, 3}};
    const migraphx::shape w_rank2{migraphx::shape::float_type, {3, 3}};
    throws_shape(qconv, in_rank2, w_rank2);
    throws_shape(qconv, in_rank2, w_i8);

    // int32 data and/or float weights are rejected.
    const migraphx::shape in_i32{migraphx::shape::int32_type, {4, 3, 3, 3}};
    const migraphx::shape w_f32{migraphx::shape::float_type, {4, 3, 3, 3}};
    throws_shape(qconv, in_i32, w_i8);
    throws_shape(qconv, in_i8, w_f32);
    throws_shape(qconv, in_i32, w_f32);
}

kahmed10's avatar
kahmed10 committed
168
169
170
171
TEST_CASE(pooling_shape)
{
    const migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};

    // Single-element attributes (with a zero stride) on a rank-4 input are rejected.
    throws_shape(
        migraphx::make_op("pooling",
                          {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1}}}),
        input);

    // 1x1 window, stride 3, no padding: the default mode yields 1x1 and
    // ceil_mode yields 2x2 spatial output.
    expect_shape(
        migraphx::shape{migraphx::shape::float_type, {4, 3, 1, 1}},
        migraphx::make_op(
            "pooling",
            {{"mode", "max"}, {"padding", {0, 0}}, {"stride", {3, 3}}, {"lengths", {1, 1}}}),
        input);
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 3, 2, 2}},
                 migraphx::make_op("pooling",
                                   {{"mode", "max"},
                                    {"padding", {0, 0}},
                                    {"stride", {3, 3}},
                                    {"lengths", {1, 1}},
                                    {"ceil_mode", true}}),
                 input);
}

194
195
196
197
TEST_CASE(inconsistent_attr_shape)
{
    // Attribute vectors of differing sizes are rejected by each windowed op.
    const migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
    const migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
    const auto mismatched = [](const char* name) {
        return migraphx::make_op(
            name, {{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}});
    };

    throws_shape(mismatched("convolution"), input, weights);
    throws_shape(mismatched("deconvolution"), input, weights);
    throws_shape(
        migraphx::make_op(
            "pooling", {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1, 1}}}),
        input);
}

Paul's avatar
Paul committed
212
TEST_CASE(transpose_shape)
{
    const migraphx::shape input{migraphx::shape::float_type, {2, 2}};
    // Swapping the two axes keeps lens {2, 2} and exchanges the strides.
    const migraphx::shape swapped{migraphx::shape::float_type, {2, 2}, {1, 2}};

    // The identity permutation leaves the shape untouched.
    expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input);
    // An explicit swap and the default attributes give the same result.
    expect_shape(swapped, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
    expect_shape(swapped, migraphx::make_op("transpose"), input);
    // Axis 2 does not exist for a rank-2 input.
    throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input);
}

Paul's avatar
Paul committed
222
TEST_CASE(contiguous_shape)
{
    const auto contiguous = migraphx::make_op("contiguous");

    // A 2x2 input with strides {1, 2} becomes a standard 2x2 output.
    const migraphx::shape strided{migraphx::shape::float_type, {2, 2}, {1, 2}};
    expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2}}, contiguous, strided);
    // Two operands are rejected.
    throws_shape(contiguous, strided, strided);

    // A standard 1-D shape is returned unchanged.
    const migraphx::shape vec{migraphx::shape::float_type, {2}};
    expect_shape(vec, contiguous, vec);
}

233
234
235
236
TEST_CASE(contiguous_shape_scalar)
{
    // A scalar shape is returned unchanged.
    const migraphx::shape scalar{migraphx::shape::float_type};
    expect_shape(scalar, migraphx::make_op("contiguous"), scalar);
}

Paul's avatar
Paul committed
240
TEST_CASE(reshape_shape)
{
    migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};

    // Explicit target dims whose element count matches the 24-element input.
    // Every entry is distinct, so each iteration checks a different layout.
    for(auto&& new_shape :
        std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {2, 3, 4, 1}})
    {
        // The expected output is a standard shape with the requested lens.
        std::vector<std::size_t> lens(new_shape.begin(), new_shape.end());
        migraphx::shape output{migraphx::shape::float_type, lens};
        expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
    }

    // Rejected dims: element counts other than 24, or more than one -1.
    for(auto&& new_shape :
        std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
    {
        throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
    }

    // -1 infers one dim from the remaining element count; 0 takes the input
    // dim at the same position (as every expected shape below shows).
    std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
        {{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
        {{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
        {{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
        {{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
        {{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
        {{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};

    for(auto& it : minus1_tests)
    {
        expect_shape(it.second, migraphx::make_op("reshape", {{"dims", it.first}}), input);
    }
}

Paul's avatar
Paul committed
275
TEST_CASE(flatten_shape)
{
    const migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};

    // Each entry pairs a flatten axis with the expected 2-D output lens: the
    // dims before the axis multiply into the first extent, the rest into the
    // second. Negative axes alias their non-negative counterparts.
    const std::vector<std::pair<int, std::vector<std::size_t>>> cases{
        {0, {1, 2 * 4 * 6 * 8}},
        {-4, {1, 2 * 4 * 6 * 8}},
        {1, {2, 4 * 6 * 8}},
        {-3, {2, 4 * 6 * 8}},
        {2, {2 * 4, 6 * 8}},
        {3, {2 * 4 * 6, 8}},
        {4, {2 * 4 * 6 * 8, 1}},
    };
    for(const auto& c : cases)
    {
        expect_shape(migraphx::shape{migraphx::shape::float_type, c.second},
                     migraphx::make_op("flatten", {{"axis", c.first}}),
                     input);
    }

    // Axes outside [-4, 4] are rejected for a rank-4 input.
    for(int axis : {5, -5})
        throws_shape(migraphx::make_op("flatten", {{"axis", axis}}), input);
}

Paul's avatar
Paul committed
303
TEST_CASE(slice_shape)
{
    const migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
    // Slicing keeps the input strides; only the lens shrink.
    const std::vector<std::size_t> input_strides{6, 3, 1};

    // [1, 3) on axis 2.
    expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, input_strides},
                 migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}),
                 input);
    // Same slice, with full ranges listed explicitly for axes 0 and 1.
    expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, input_strides},
                 migraphx::make_op(
                     "slice", {{"axes", {0, 1, 2}}, {"starts", {0, 0, 1}}, {"ends", {2, 2, 3}}}),
                 input);
    // An end past the axis extent is clamped: [2, 10) on an extent-3 axis keeps one element.
    expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, input_strides},
                 migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {10}}}),
                 input);
}
Scott Thornton's avatar
Scott Thornton committed
317

wsttiger's avatar
wsttiger committed
318
TEST_CASE(multibroadcast)
{
    // Accepted broadcasts: input lens, target lens, expected output strides.
    struct broadcast_case
    {
        std::vector<std::size_t> input_lens;
        std::vector<std::size_t> output_lens;
        std::vector<std::size_t> strides;
    };
    const std::vector<broadcast_case> accepted{
        {{2, 1, 3}, {4, 2, 5, 3}, {0, 3, 0, 1}},
        {{2, 1, 1}, {4, 2, 5, 3}, {0, 1, 0, 0}},
        {{5, 1}, {4, 2, 5, 3}, {0, 0, 1, 0}},
        {{4, 1, 1, 1}, {4, 2, 5, 3}, {1, 0, 0, 0}},
        {{3}, {4, 2, 5, 3}, {0, 0, 0, 1}},
        {{4, 1, 3}, {4, 4, 1, 3}, {0, 3, 3, 1}},
        {{4, 1, 1, 1}, {4, 1, 1, 3}, {1, 1, 1, 0}},
    };
    for(const auto& c : accepted)
    {
        expect_shape(migraphx::shape{migraphx::shape::float_type, c.output_lens, c.strides},
                     migraphx::make_op("multibroadcast", {{"output_lens", c.output_lens}}),
                     migraphx::shape{migraphx::shape::float_type, c.input_lens});
    }

    // Rejected pairs: input lens, target lens.
    const std::vector<std::pair<std::vector<std::size_t>, std::vector<std::size_t>>> rejected{
        {{4, 1, 1, 1}, {4, 1, 3}},
        {{}, {4, 1, 3}},
        {{3, 4}, {2, 3, 4, 5}},
        {{2, 3, 4}, {2, 3, 4, 5}},
    };
    for(const auto& c : rejected)
    {
        throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", c.second}}),
                     migraphx::shape{migraphx::shape::float_type, c.first});
    }
}

391
392
393
394
395
396
TEST_CASE(broadcast)
{
    const auto broadcast_op = [](int axis, const std::vector<std::size_t>& dims) {
        return migraphx::make_op("broadcast", {{"axis", axis}, {"dims", dims}});
    };

    // Target {1, 1} from a rank-3 input at axis 0.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
                 broadcast_op(0, {1, 1}),
                 migraphx::shape{migraphx::shape::float_type, {4, 1, 3}});

    // Target {1, 1} from a zero-stride 1-element input at axis 1.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
                 broadcast_op(1, {1, 1}),
                 migraphx::shape{migraphx::shape::float_type, {1}, {0}});

    // A {4, 3} input placed at axis 2 of {3, 2, 4, 3} keeps its own strides
    // and gets zero strides on the leading dims.
    const std::vector<std::size_t> target{3, 2, 4, 3};
    expect_shape(migraphx::shape{migraphx::shape::float_type, target, {0, 0, 3, 1}},
                 broadcast_op(2, target),
                 migraphx::shape{migraphx::shape::float_type, {4, 3}});

    // A {4, 4} input at axis 2 of the same target is rejected.
    throws_shape(broadcast_op(2, target), migraphx::shape{migraphx::shape::float_type, {4, 4}});
}

Shucai Xiao's avatar
Shucai Xiao committed
424
TEST_CASE(gather)
{
    const auto gather_op = [](int axis) { return migraphx::make_op("gather", {{"axis", axis}}); };

    const migraphx::shape data_4d{migraphx::shape::float_type, {2, 3, 4, 5}};
    const migraphx::shape data_1d{migraphx::shape::float_type, {3}};
    const migraphx::shape idx_2x3{migraphx::shape::int32_type, {2, 3}};
    const migraphx::shape idx_1{migraphx::shape::int32_type, {1}};
    const migraphx::shape idx_scalar{migraphx::shape::int32_type};

    // Tensor indices: the gathered axis is replaced by the indices' lens.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}},
                 gather_op(1),
                 data_4d,
                 idx_2x3);
    expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}},
                 gather_op(-4),
                 data_4d,
                 idx_2x3);
    expect_shape(
        migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, gather_op(-4), data_4d, idx_1);

    // Scalar indices: the gathered axis disappears.
    expect_shape(
        migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}, gather_op(-4), data_4d, idx_scalar);
    expect_shape(
        migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}, gather_op(3), data_4d, idx_scalar);
    expect_shape(migraphx::shape{migraphx::shape::float_type}, gather_op(0), data_1d, idx_scalar);
    expect_shape(migraphx::shape{migraphx::shape::float_type, {1}}, gather_op(0), data_1d, idx_1);

    // Axes outside [-4, 3] are rejected for a rank-4 input.
    throws_shape(gather_op(4), data_4d, idx_2x3);
    throws_shape(gather_op(-5), data_4d, idx_2x3);
}

Khalique's avatar
Khalique committed
511
template <class T>
Khalique's avatar
Khalique committed
512
void test_softmax_variations()
513
514
515
{
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
516
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
517
518
519
520
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
521
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
522
523
524
525
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
526
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
527
528
529
530
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
531
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
532
533
534
535
536
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        int axis = 4;
Khalique's avatar
Khalique committed
537
        throws_shape(T{axis}, input);
538
539
540
    }
}

Khalique's avatar
Khalique committed
541
// softmax and logsoftmax run the same shape checks.
TEST_CASE(softmax)
{
    test_softmax_variations<migraphx::op::softmax>();
}

TEST_CASE(logsoftmax)
{
    test_softmax_variations<migraphx::op::logsoftmax>();
}
544

Shucai Xiao's avatar
Shucai Xiao committed
545
TEST_CASE(test_argmax)
{
    const std::vector<std::size_t> in_lens{2, 3, 4, 5};
    const migraphx::shape input{migraphx::shape::half_type, in_lens};

    // Each valid axis collapses to extent 1, and the result is int64.
    for(int axis = 0; axis < 4; ++axis)
    {
        auto out_lens                          = in_lens;
        out_lens[static_cast<std::size_t>(axis)] = 1;
        expect_shape(migraphx::shape{migraphx::shape::int64_type, out_lens},
                     migraphx::make_op("argmax", {{"axis", axis}}),
                     input);
    }

    // Axis 4 is out of range for a rank-4 input.
    throws_shape(migraphx::make_op("argmax", {{"axis", 4}}),
                 migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
}
580

Shucai Xiao's avatar
Shucai Xiao committed
581
TEST_CASE(test_argmin)
{
    const std::vector<std::size_t> in_lens{2, 3, 4, 5};
    const migraphx::shape input{migraphx::shape::half_type, in_lens};

    // Each valid axis collapses to extent 1, and the result is int64.
    for(int axis = 0; axis < 4; ++axis)
    {
        auto out_lens                          = in_lens;
        out_lens[static_cast<std::size_t>(axis)] = 1;
        expect_shape(migraphx::shape{migraphx::shape::int64_type, out_lens},
                     migraphx::make_op("argmin", {{"axis", axis}}),
                     input);
    }

    // Axis 4 is out of range for a rank-4 input.
    throws_shape(migraphx::make_op("argmin", {{"axis", 4}}),
                 migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
}
616

617
618
619
620
TEST_CASE(test_scalar)
{
    // A one-element input is broadcast, with all-zero strides, to scalar_bcst_dims.
    const migraphx::shape one_elem{migraphx::shape::float_type, {1}, {1}};
    const migraphx::shape bcast{migraphx::shape::float_type, {2, 3, 4, 5}, {0, 0, 0, 0}};
    const auto scalar_op = migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}});
    expect_shape(bcast, scalar_op, one_elem);
}

TEST_CASE(test_scalar_nelemnts)
{
    // A multi-element input is rejected by the scalar op.
    const auto scalar_op = migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}});
    throws_shape(scalar_op, migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
}

630
631
TEST_CASE(test_squeeze)
{
    // Removes the unit dimension at axis 3.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}},
                 migraphx::make_op("squeeze", {{"axes", {3}}}),
                 migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 1, 3}});
}
636

637
638
639
640
TEST_CASE(test_squeeze_negative_axis)
{
    // Axis -2 of a rank-5 input refers to the same unit dimension as axis 3.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}},
                 migraphx::make_op("squeeze", {{"axes", {-2}}}),
                 migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 1, 3}});
}

TEST_CASE(test_squeeze_wrong_axis)
{
    // Axis 0 has extent 4, so squeezing it is rejected.
    const auto squeeze = migraphx::make_op("squeeze", {{"axes", {0}}});
    throws_shape(squeeze, migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 1, 3}});
}

TEST_CASE(test_squeeze_all)
{
    // Squeezing the only dimension of a {1} tensor yields a scalar shape.
    expect_shape(migraphx::shape{migraphx::shape::float_type},
                 migraphx::make_op("squeeze", {{"axes", {0}}}),
                 migraphx::shape{migraphx::shape::float_type, {1}});
}

TEST_CASE(test_unsqueeze_scalar)
{
    // A zero-stride {1} input unsqueezed at axis 0 gives {1} with stride 1.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {1}, {1}},
                 migraphx::make_op("unsqueeze", {{"axes", {0}}}),
                 migraphx::shape{migraphx::shape::float_type, {1}, {0}});
}

TEST_CASE(test_unsqueeze_scalar_tensor1)
{
    // An all-zero-stride {4, 3, 3} input is rejected.
    const auto unsqueeze = migraphx::make_op("unsqueeze", {{"axes", {-2}}});
    throws_shape(unsqueeze, migraphx::shape{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}});
}

TEST_CASE(test_unsqueeze_scalar_tensor2)
{
    // An all-zero-stride {1, 1, 1} input is rejected.
    const auto unsqueeze = migraphx::make_op("unsqueeze", {{"axes", {-2}}});
    throws_shape(unsqueeze, migraphx::shape{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}});
}

TEST_CASE(test_unsqueeze)
{
    // Inserts a unit dimension at axis 2.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 3, 1, 3}},
                 migraphx::make_op("unsqueeze", {{"axes", {2}}}),
                 migraphx::shape{migraphx::shape::float_type, {4, 3, 3}});
}

TEST_CASE(test_unsqueeze_negative_axis)
{
    // Axis -2 on a rank-3 input gives the same result as axis 2.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 3, 1, 3}},
                 migraphx::make_op("unsqueeze", {{"axes", {-2}}}),
                 migraphx::shape{migraphx::shape::float_type, {4, 3, 3}});
}

Shucai Xiao's avatar
Shucai Xiao committed
690
691
template <class T>
void test_reduce_ops()
Shucai Xiao's avatar
Shucai Xiao committed
692
{
693
694
695
696
697
698
699
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
700
701
        expect_shape(
            migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
702
    }
Shucai Xiao's avatar
Shucai Xiao committed
703
704
705
706
707
708
709
710
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
    }
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
    }
711
712
713
714
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
    }
715
716
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
717
        throws_shape(T{{4}}, input);
718
719
720
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
721
722
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
// Run the shared reduction shape checks for reduce_mean.
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }

// 2 input arguments: shape inference for "dot" with A and B only.
TEST_CASE(matmul)
{
    // Two 1-D inputs are rejected.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
    }

    // 1-D A with 2-D B is rejected.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
    }

    // 2-D A with 1-D B is rejected.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
    }

    // {1, 5} x {5, 4} -> {1, 4}.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4}},
                     migraphx::make_op("dot"),
                     s_m1,
                     s_m2);
    }

    // Inner dimensions differ (5 vs 4).
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
    }

    // Ranks differ (2 vs 3).
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
    }

    // Batched rank-3 product.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
                     migraphx::make_op("dot"),
                     s_m1,
                     s_m2);
    }

    // Batched rank-4 product.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
                     migraphx::make_op("dot"),
                     s_m1,
                     s_m2);
    }

    // Plain 2-D product.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
                     migraphx::make_op("dot"),
                     s_m1,
                     s_m2);
    }

    // Degenerate 1x1 product.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}},
                     migraphx::make_op("dot"),
                     s_m1,
                     s_m2);
    }

    // Rank-3 product with a batch size of 1.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}},
                     migraphx::make_op("dot"),
                     s_m1,
                     s_m2);
    }

    // Ranks differ (2 vs 4), even though the extra dims are all 1.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
    }

    // Batch dimensions differ (1 vs 2).
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
    }
}

// 3 input arguments: shape inference for "dot" with A, B and C.
// In these cases C must have exactly the output shape of A x B; shapes that
// differ (including smaller ones such as {1}, {8} or {4, 1}) are rejected.
TEST_CASE(gemm)
{
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {1}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {8}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
    }

    // Inner dimensions of A and B differ (6 vs 5).
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {4}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
    }

    // C matches the output shape exactly.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
                     migraphx::make_op("dot"),
                     s_m1,
                     s_m2,
                     s_m3);
    }

    // Batched variant with a matching C.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
                     migraphx::make_op("dot"),
                     s_m1,
                     s_m2,
                     s_m3);
    }

    // Batched variant with mismatched inner dimensions (6 vs 5).
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
    }

    // Batched A/B but C is missing the batch dimension.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
    }

    // C is the type-only shape.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type};
        throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
    }
}

913
914
915
916
917
918
919
// quant_dot
TEST_CASE(quant_dot_2args)
{
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
920
                     migraphx::make_op("quant_dot"),
921
922
923
924
925
926
927
928
                     s_m1,
                     s_m2);
    }

    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
929
                     migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}),
930
931
932
933
934
935
936
                     s_m1,
                     s_m2);
    }

    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}};
937
        throws_shape(migraphx::make_op("quant_dot"), s_m1, s_m2);
938
939
940
941
942
    }

    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
943
        throws_shape(migraphx::make_op("quant_dot"), s_m1, s_m2);
944
945
946
947
948
949
950
951
952
953
    }
}

TEST_CASE(quant_dot_3args)
{
    // An int32 C that matches the output shape is accepted.
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
        migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
                     migraphx::make_op("quant_dot"),
                     s_m1,
                     s_m2,
                     s_m3);
    }

    // C has the right lens but type int8 instead of int32, so it is rejected.
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
        migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
        throws_shape(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), s_m1, s_m2, s_m3);
    }
}

968
969
970
971
972
973
974
975
TEST_CASE(rnn)
{
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
976
        float clip              = 0.0f;
977
978
979
980
981
982
983

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
984
985
986
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
987
988
989
990
991
992
993
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
994
995
996
997
998
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
999
1000
1001
1002
1003
1004
1005
1006
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
1007
        float clip              = 0.0f;
1008
1009
1010
1011
1012
1013
1014

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1015
1016
1017
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
1018
1019
1020
1021
1022
1023
1024
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
1025
1026
1027
1028
1029
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1030
1031
1032
1033
1034
1035
1036
1037
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
Shucai Xiao's avatar
Shucai Xiao committed
1038
        float clip              = 0.0f;
1039
1040
1041
1042
1043
1044
1045

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1061
1062
1063
1064
1065
1066
1067
1068
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
1069
        float clip              = 0.0f;
1070
1071
1072
1073
1074
1075
1076

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
        throws_shape(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size + 1},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1090
1091
1092
1093
1094
1095
1096
1097
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
1098
        float clip              = 0.0f;
1099
1100
1101
1102
1103
1104
1105

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
        throws_shape(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1119
1120
1121
1122
1123
1124
1125
1126
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
Shucai Xiao's avatar
Shucai Xiao committed
1127
        float clip              = 0.0f;
1128
1129
1130
1131
1132
1133
1134

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1135
        throws_shape(
1136
1137
1138
1139
1140
1141
1142
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
1143
1144
1145
1146
1147
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1148
1149
1150
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
TEST_CASE(gru)
{
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1162
1163
1164
1165
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1166
1167
1168
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1169
1170
1171
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
1172
1173
1174
1175
1176
1177
1178
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
1179
1180
1181
1182
1183
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1195
1196
1197
1198
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1199
1200
1201
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1202
1203
1204
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
1205
1206
1207
1208
1209
1210
1211
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
1212
1213
1214
1215
1216
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1228
1229
1230
1231
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1232
1233
1234
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1261
1262
1263
1264
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1265
1266
1267
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
        throws_shape(
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size + 1},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1292
1293
1294
1295
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1296
1297
1298
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
        throws_shape(
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1323
1324
1325
1326
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1327
1328
1329
1330
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        throws_shape(
1331
1332
1333
1334
1335
1336
1337
            migraphx::make_op(
                "gru",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
Shucai Xiao's avatar
Shucai Xiao committed
1338
1339
1340
1341
1342
1343
1344
1345
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }
}

1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
// Shape-inference tests for the "lstm" operator.
//
// The weight/bias fixtures follow the ONNX LSTM input layout, which packs
// four gates per direction:
//   W : [num_directions, 4 * hidden_size, input_size]
//   R : [num_directions, 4 * hidden_size, hidden_size]
//   B : [num_directions, 8 * hidden_size]   (input and recurrence biases)
// (3x / 6x multipliers are the GRU layout and must not be reused here.)
// Every valid case expects an output of
//   [seq_len, num_directions, batch_size, hidden_size].
TEST_CASE(lstm)
{
    // Forward direction with only the three required inputs (X, W, R).
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, hidden_size}};

        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape);
    }

    // Reverse direction with bias (B) and initial hidden state supplied.
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }

    // Bidirectional: all per-direction inputs carry a leading dim of 2,
    // and the output's direction axis is 2 as well.
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }

    // Error: the "hidden_size" attribute (hidden_size + 1) disagrees with the
    // hidden dimension carried by the R input (hidden_size).
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        throws_shape(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size + 1},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }

    // Error: bidirectional attribute, but the inputs only carry one direction.
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        throws_shape(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }

    // Error: forward (single-direction) attribute, but the inputs carry two
    // directions.
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        throws_shape(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }
}

Paul's avatar
Paul committed
1537
// Program entry point: forwards the command-line arguments unchanged to the
// test harness's test::run (from test.hpp), which presumably executes the
// TEST_CASEs registered in this file. Falls off the end, so returns 0.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}