op_shape_test.cpp 52.6 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <type_traits>
#include <utility>
#include <vector>

#include "test.hpp"

Paul's avatar
Paul committed
8
template <class... Ts>
Paul's avatar
Paul committed
9
void expect_shape(const migraphx::shape& expected, const migraphx::operation& op, Ts... xs)
10
{
Paul's avatar
Paul committed
11
12
13
    migraphx::program p;
    std::vector<migraphx::shape> shapes{xs...};
    std::vector<migraphx::instruction_ref> args(shapes.size());
Paul's avatar
Paul committed
14
15
    std::transform(
        shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
16
    p.add_instruction(op, args);
17
    if(p.get_output_shapes().back() != expected)
Paul's avatar
Paul committed
18
    {
19
        std::cout << "FAILED: Incorrect shape for " << op.name() << ": ";
20
        std::cout << expected << " != " << p.get_output_shapes().back() << std::endl;
Paul's avatar
Paul committed
21
        for(auto&& s : shapes)
22
23
24
25
            std::cout << "    " << s << std::endl;
    }
}

Paul's avatar
Paul committed
26
template <class... Ts>
Paul's avatar
Paul committed
27
void throws_shape(const migraphx::operation& op, Ts... xs)
28
{
Paul's avatar
Paul committed
29
30
31
    migraphx::program p;
    std::vector<migraphx::shape> shapes{xs...};
    std::vector<migraphx::instruction_ref> args(shapes.size());
Paul's avatar
Paul committed
32
33
    std::transform(
        shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
34
    bool thrown = test::throws([&] { p.add_instruction(op, args); });
Paul's avatar
Paul committed
35
36
    if(not thrown)
    {
37
        std::cout << "FAILED: No error found for " << op.name() << ": ";
Paul's avatar
Paul committed
38
        for(auto&& s : shapes)
39
40
41
42
            std::cout << "    " << s << std::endl;
    }
}

Paul's avatar
Paul committed
43
44
45
46
// A constant that is false for every argument pack. Because its value depends on
// the template arguments, a static_assert on it only fires when the enclosing
// template is actually instantiated.
template <class...>
struct always_false : std::integral_constant<bool, false>
{
};
47

Paul's avatar
Paul committed
48
template <class... Ts>
Paul's avatar
Paul committed
49
void throws_shape(const migraphx::shape&, Ts...)
50
{
Paul's avatar
Paul committed
51
52
    static_assert(always_false<Ts...>{},
                  "An expected shape should not be passed to throws_shape function");
53
54
}

Paul's avatar
Paul committed
55
TEST_CASE(batch_norm_inference_shape)
{
    const size_t num_channels = 3;
    const migraphx::shape x{migraphx::shape::float_type, {4, num_channels, 3, 3}};
    // Per-channel operand: one element per channel of x.
    const migraphx::shape per_channel{migraphx::shape::float_type, {num_channels}};
    const migraphx::op::batch_norm_inference bn{};

    // Input plus four per-channel operands: output keeps the input shape.
    expect_shape(x, bn, x, per_channel, per_channel, per_channel, per_channel);
    // Too few and too many arguments are rejected.
    throws_shape(bn, x);
    throws_shape(bn, x, per_channel, per_channel, per_channel, per_channel, per_channel);
}

Paul's avatar
Paul committed
65
TEST_CASE(convolution_shape)
{
    const auto f32 = migraphx::shape::float_type;

    // 2-D convolution with a default-constructed op.
    const migraphx::shape in_2d{f32, {4, 3, 3, 3}};
    const migraphx::shape w_2d{f32, {4, 3, 3, 3}};
    expect_shape(migraphx::shape{f32, {4, 4, 1, 1}}, migraphx::op::convolution{}, in_2d, w_2d);
    // The weights argument is mandatory.
    throws_shape(migraphx::op::convolution{}, in_2d);

    // Rank-2 operands are not valid convolution inputs.
    const migraphx::shape in_rank2{f32, {3, 3}};
    const migraphx::shape w_rank2{f32, {3, 3}};
    throws_shape(migraphx::op::convolution{}, in_rank2, w_rank2);
    throws_shape(migraphx::op::convolution{}, in_rank2, w_2d);

    // 1-D convolution with single-element attribute vectors.
    expect_shape(migraphx::shape{f32, {4, 4, 1}},
                 migraphx::op::convolution{{0}, {1}, {1}},
                 migraphx::shape{f32, {4, 3, 3}},
                 migraphx::shape{f32, {4, 3, 3}});

    // 3-D convolution with three-element attribute vectors.
    const migraphx::shape in_3d{f32, {4, 3, 3, 3, 3}};
    const migraphx::shape w_3d{f32, {4, 3, 3, 3, 3}};
    expect_shape(migraphx::shape{f32, {4, 4, 1, 1, 1}},
                 migraphx::op::convolution{{0, 0, 0}, {1, 1, 1}, {1, 1, 1}},
                 in_3d,
                 w_3d);
    // The default-constructed op is expected to reject these 3-D operands.
    throws_shape(migraphx::op::convolution{}, in_3d, w_3d);
}

kahmed10's avatar
kahmed10 committed
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
TEST_CASE(deconvolution_shape)
{
    const auto f32 = migraphx::shape::float_type;

    // 2-D: a 1x1 spatial input with a 3x3 kernel grows to a 3x3 output.
    const migraphx::shape x_2d{f32, {4, 4, 1, 1}};
    const migraphx::shape k_2d{f32, {4, 3, 3, 3}};
    expect_shape(migraphx::shape{f32, {4, 3, 3, 3}}, migraphx::op::deconvolution{}, x_2d, k_2d);
    // The kernel argument is mandatory.
    throws_shape(migraphx::op::deconvolution{}, x_2d);

    // 1-D variant with single-element attribute vectors.
    expect_shape(migraphx::shape{f32, {4, 3, 3}},
                 migraphx::op::deconvolution{{0}, {1}, {1}},
                 migraphx::shape{f32, {4, 4, 1}},
                 migraphx::shape{f32, {4, 3, 3}});

    // 3-D variant with three-element attribute vectors.
    expect_shape(migraphx::shape{f32, {4, 3, 3, 3, 3}},
                 migraphx::op::deconvolution{{0, 0, 0}, {1, 1, 1}, {1, 1, 1}},
                 migraphx::shape{f32, {4, 4, 1, 1, 1}},
                 migraphx::shape{f32, {4, 3, 3, 3, 3}});
}

116
117
TEST_CASE(quant_convolution_shape)
{
    const auto i8  = migraphx::shape::int8_type;
    const auto i32 = migraphx::shape::int32_type;
    const auto f32 = migraphx::shape::float_type;

    // int8 input and weights produce an int32 result.
    const migraphx::shape x{i8, {4, 3, 3, 3}};
    const migraphx::shape w{i8, {4, 3, 3, 3}};
    expect_shape(migraphx::shape{i32, {4, 4, 1, 1}}, migraphx::op::quant_convolution{}, x, w);

    // Missing weights, or attribute vectors of differing lengths, are rejected.
    throws_shape(migraphx::op::quant_convolution{}, x);
    throws_shape(migraphx::op::quant_convolution{{0}, {1, 1}, {1, 1}}, x, w);

    // Rank-2, non-int8 inputs are rejected.
    const migraphx::shape x_rank2{i32, {3, 3}};
    throws_shape(migraphx::op::quant_convolution{}, x_rank2, migraphx::shape{f32, {3, 3}});
    throws_shape(migraphx::op::quant_convolution{}, x_rank2, w);

    // Any operand that is not int8 is rejected.
    const migraphx::shape x_i32{i32, {4, 3, 3, 3}};
    const migraphx::shape w_f32{f32, {4, 3, 3, 3}};
    throws_shape(migraphx::op::quant_convolution{}, x_i32, w);
    throws_shape(migraphx::op::quant_convolution{}, x, w_f32);
    throws_shape(migraphx::op::quant_convolution{}, x_i32, w_f32);
}

kahmed10's avatar
kahmed10 committed
137
138
139
140
141
142
143
144
TEST_CASE(pooling_shape)
{
    const migraphx::shape x{migraphx::shape::float_type, {4, 3, 3, 3}};
    // Single-element attribute vectors do not fit the rank-4 input.
    throws_shape(migraphx::op::pooling{"max", {1}, {0}, {1}}, x);
    // With the {3, 3} attribute (presumably the window lengths) over a 3x3 input,
    // each channel reduces to 1x1.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 3, 1, 1}},
                 migraphx::op::pooling{"max", {0, 0}, {1, 1}, {3, 3}},
                 x);
}

145
146
147
148
149
TEST_CASE(inconsistent_attr_shape)
{
    const migraphx::shape x{migraphx::shape::float_type, {4, 3, 3, 3}};
    const migraphx::shape w{migraphx::shape::float_type, {4, 3, 3, 3}};

    // Each op below is given attribute vectors whose lengths disagree with each
    // other; all of them must be rejected.
    throws_shape(migraphx::op::convolution{{1, 1}, {2}, {3, 3, 3}}, x, w);
    throws_shape(migraphx::op::deconvolution{{1, 1}, {2}, {3, 3, 3}}, x, w);
    throws_shape(migraphx::op::pooling{"max", {1}, {0}, {1, 1}}, x);
}

Paul's avatar
Paul committed
154
TEST_CASE(transpose_shape)
{
    const migraphx::shape square{migraphx::shape::float_type, {2, 2}};
    // Swapping the axes keeps the lens and swaps the strides.
    const migraphx::shape swapped{migraphx::shape::float_type, {2, 2}, {1, 2}};

    // Identity permutation.
    expect_shape(square, migraphx::op::transpose{{0, 1}}, square);
    expect_shape(swapped, migraphx::op::transpose{{1, 0}}, square);
    // A default-constructed op yields the same result as {1, 0} for this rank-2 input.
    expect_shape(swapped, migraphx::op::transpose{}, square);
    // {1, 2} is not a permutation of the axes {0, 1}.
    throws_shape(migraphx::op::transpose{{1, 2}}, square);
}

Paul's avatar
Paul committed
164
TEST_CASE(contiguous_shape)
{
    // A transposed-stride input becomes standard-strided.
    const migraphx::shape strided{migraphx::shape::float_type, {2, 2}, {1, 2}};
    const migraphx::shape packed{migraphx::shape::float_type, {2, 2}};
    expect_shape(packed, migraphx::op::contiguous{}, strided);
    // Two arguments are rejected.
    throws_shape(migraphx::op::contiguous{}, strided, strided);

    // An already standard 1-D input comes back unchanged.
    const migraphx::shape vec{migraphx::shape::float_type, {2}};
    expect_shape(vec, migraphx::op::contiguous{}, vec);
}

175
176
177
178
179
180
181
TEST_CASE(contiguous_shape_scalar)
{
    // A scalar shape (no lens) passes through contiguous unchanged.
    const migraphx::shape scalar{migraphx::shape::float_type};
    expect_shape(scalar, migraphx::op::contiguous{}, scalar);
}

Paul's avatar
Paul committed
182
TEST_CASE(reshape_shape)
{
    migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};

    // Explicit target dims, each with 24 elements: the output is a standard
    // shape with exactly those lens. Every entry is distinct, so each iteration
    // checks a different layout.
    for(auto&& new_shape :
        std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {2, 3, 4, 1}})
    {
        std::vector<std::size_t> lens(new_shape.begin(), new_shape.end());
        migraphx::shape output{migraphx::shape::float_type, lens};
        expect_shape(output, migraphx::op::reshape{new_shape}, input);
    }

    // Rejected: the element count does not match ({8, 3, 2, 2}, {3, 0, 0}, {3, 2, 0}),
    // or more than one dimension is to be inferred ({1, 3, -1, -1}).
    for(auto&& new_shape :
        std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
    {
        throws_shape(migraphx::op::reshape{new_shape}, input);
    }

    // 0 copies the input dimension at the same position; -1 is inferred so that
    // the element count stays 24.
    std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
        {{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
        {{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
        {{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
        {{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
        {{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
        {{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};

    for(const auto& it : minus1_tests)
    {
        expect_shape(it.second, migraphx::op::reshape{it.first}, input);
    }
}

Paul's avatar
Paul committed
217
TEST_CASE(flatten_shape)
{
    const migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};

    // Each entry pairs a flatten axis with the expected 2-D lens: dims before the
    // axis fold into the first output dim and the rest into the second. Negative
    // axes count from the end (-4 behaves like 0, -3 like 1).
    const std::vector<std::pair<int, std::vector<std::size_t>>> valid{
        {0, {1, 2 * 4 * 6 * 8}},
        {-4, {1, 2 * 4 * 6 * 8}},
        {1, {2, 4 * 6 * 8}},
        {-3, {2, 4 * 6 * 8}},
        {2, {2 * 4, 6 * 8}},
        {3, {2 * 4 * 6, 8}},
        {4, {2 * 4 * 6 * 8, 1}}};
    for(const auto& entry : valid)
    {
        expect_shape(migraphx::shape{migraphx::shape::float_type, entry.second},
                     migraphx::op::flatten{entry.first},
                     input);
    }

    // Axes outside [-rank, rank] are rejected.
    for(int bad_axis : {5, -5})
    {
        throws_shape(migraphx::op::flatten{bad_axis}, input);
    }
}

Paul's avatar
Paul committed
245
TEST_CASE(slice_shape)
{
    const auto i32 = migraphx::shape::int32_type;
    const migraphx::shape input{i32, {2, 2, 3}};
    // The sliced result keeps the input strides; only the lens shrink.
    const std::vector<std::size_t> input_strides{6, 3, 1};

    // Elements [1, 3) of the last axis.
    expect_shape(migraphx::shape{i32, {2, 2, 2}, input_strides},
                 migraphx::op::slice{{2}, {1}, {3}},
                 input);
    // The same slice with every axis listed explicitly.
    expect_shape(migraphx::shape{i32, {2, 2, 2}, input_strides},
                 migraphx::op::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}},
                 input);
    // An end beyond the axis length is clamped, leaving one element.
    expect_shape(migraphx::shape{i32, {2, 2, 1}, input_strides},
                 migraphx::op::slice{{2}, {2}, {10}},
                 input);
}
Scott Thornton's avatar
Scott Thornton committed
258

wsttiger's avatar
wsttiger committed
259
TEST_CASE(multibroadcast)
{
    // Successful cases: target lens, input lens, expected output strides.
    // Broadcast dimensions get stride 0; matching dimensions keep the input stride.
    struct broadcast_case
    {
        std::vector<std::size_t> lens;
        std::vector<std::size_t> input_lens;
        std::vector<std::size_t> strides;
    };
    const std::vector<broadcast_case> valid{
        {{4, 2, 5, 3}, {2, 1, 3}, {0, 3, 0, 1}},
        {{4, 2, 5, 3}, {2, 1, 1}, {0, 1, 0, 0}},
        {{4, 2, 5, 3}, {5, 1}, {0, 0, 1, 0}},
        {{4, 2, 5, 3}, {4, 1, 1, 1}, {1, 0, 0, 0}},
        {{4, 2, 5, 3}, {3}, {0, 0, 0, 1}},
        {{4, 4, 1, 3}, {4, 1, 3}, {0, 3, 3, 1}},
        {{4, 1, 1, 3}, {4, 1, 1, 1}, {1, 1, 1, 0}}};
    for(const auto& c : valid)
    {
        migraphx::shape input{migraphx::shape::float_type, c.input_lens};
        expect_shape(migraphx::shape{migraphx::shape::float_type, c.lens, c.strides},
                     migraphx::op::multibroadcast{c.lens},
                     input);
    }

    // Rejected: input rank larger than the target, empty input lens.
    const std::vector<std::size_t> rank3_lens{4, 1, 3};
    throws_shape(migraphx::op::multibroadcast{rank3_lens},
                 migraphx::shape{migraphx::shape::float_type, {4, 1, 1, 1}});
    throws_shape(migraphx::op::multibroadcast{rank3_lens},
                 migraphx::shape{migraphx::shape::float_type, {}});

    // Rejected: trailing dims that do not match the target.
    const std::vector<std::size_t> rank4_lens{2, 3, 4, 5};
    throws_shape(migraphx::op::multibroadcast{rank4_lens},
                 migraphx::shape{migraphx::shape::float_type, {3, 4}});
    throws_shape(migraphx::op::multibroadcast{rank4_lens},
                 migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
}

332
333
334
335
336
337
338
339
340
341
342
343
TEST_CASE(broadcast)
{
    const auto f32 = migraphx::shape::float_type;

    const std::vector<std::size_t> small_lens{1, 1};
    const migraphx::shape rank3{f32, {4, 1, 3}};
    expect_shape(migraphx::shape{f32, small_lens, {0, 0}},
                 migraphx::op::broadcast{0, small_lens},
                 rank3);
    throws_shape(migraphx::op::broadcast{1, small_lens}, rank3);

    const std::vector<std::size_t> big_lens{3, 2, 4, 3};
    // A {4, 3} input placed at axis 2 lines up with big_lens[2..3]; the leading
    // dims are broadcast with stride 0.
    expect_shape(migraphx::shape{f32, big_lens, {0, 0, 3, 1}},
                 migraphx::op::broadcast{2, big_lens},
                 migraphx::shape{f32, {4, 3}});
    // {4, 4} does not match big_lens[2..3] == {4, 3}.
    throws_shape(migraphx::op::broadcast{2, big_lens}, migraphx::shape{f32, {4, 4}});
}

Shucai Xiao's avatar
Shucai Xiao committed
362
TEST_CASE(gather)
{
    const auto f32 = migraphx::shape::float_type;
    const auto i32 = migraphx::shape::int32_type;

    const migraphx::shape data{f32, {2, 3, 4, 5}};
    const migraphx::shape idx_2x3{i32, {2, 3}};
    const migraphx::shape idx_1{i32, {1}};
    const migraphx::shape idx_scalar{i32};

    // The gathered axis is replaced by the indices' lens: {2, 3} indices add a
    // dimension, {1} indices keep the rank, scalar indices drop the axis.
    expect_shape(migraphx::shape{f32, {2, 2, 3, 4, 5}}, migraphx::op::gather{1}, data, idx_2x3);
    expect_shape(migraphx::shape{f32, {2, 3, 3, 4, 5}}, migraphx::op::gather{-4}, data, idx_2x3);
    expect_shape(migraphx::shape{f32, {1, 3, 4, 5}}, migraphx::op::gather{-4}, data, idx_1);
    expect_shape(migraphx::shape{f32, {3, 4, 5}}, migraphx::op::gather{-4}, data, idx_scalar);
    expect_shape(migraphx::shape{f32, {2, 3, 4}}, migraphx::op::gather{3}, data, idx_scalar);

    // Rank-1 data.
    const migraphx::shape vec3{f32, {3}};
    expect_shape(migraphx::shape{f32}, migraphx::op::gather{0}, vec3, idx_scalar);
    expect_shape(migraphx::shape{f32, {1}}, migraphx::op::gather{0}, vec3, idx_1);

    // Axes outside [-rank, rank) are rejected.
    throws_shape(migraphx::op::gather{4}, data, idx_2x3);
    throws_shape(migraphx::op::gather{-5}, data, idx_2x3);
}

Khalique's avatar
Khalique committed
449
template <class T>
Khalique's avatar
Khalique committed
450
void test_softmax_variations()
451
452
453
{
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
454
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
455
456
457
458
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
459
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
460
461
462
463
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
464
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
465
466
467
468
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Khalique's avatar
Khalique committed
469
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
470
471
472
473
474
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        int axis = 4;
Khalique's avatar
Khalique committed
475
        throws_shape(T{axis}, input);
476
477
478
    }
}

Khalique's avatar
Khalique committed
479
// softmax and logsoftmax are checked against the same shape rules.
TEST_CASE(softmax)
{
    test_softmax_variations<migraphx::op::softmax>();
}

TEST_CASE(logsoftmax)
{
    test_softmax_variations<migraphx::op::logsoftmax>();
}
482

Shucai Xiao's avatar
Shucai Xiao committed
483
TEST_CASE(test_argmax)
484
485
486
{
    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
487
488
489
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
                     migraphx::op::argmax{0},
                     input);
Shucai Xiao's avatar
Shucai Xiao committed
490
491
492
493
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
494
495
496
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
                     migraphx::op::argmax{1},
                     input);
Shucai Xiao's avatar
Shucai Xiao committed
497
498
499
500
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
501
502
503
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
                     migraphx::op::argmax{2},
                     input);
Shucai Xiao's avatar
Shucai Xiao committed
504
505
506
507
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
508
509
510
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
                     migraphx::op::argmax{3},
                     input);
Shucai Xiao's avatar
Shucai Xiao committed
511
512
513
514
515
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        throws_shape(migraphx::op::argmax{4}, input);
516
    }
Shucai Xiao's avatar
Shucai Xiao committed
517
}
518

Shucai Xiao's avatar
Shucai Xiao committed
519
TEST_CASE(test_argmin)
520
521
522
{
    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
523
524
525
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
                     migraphx::op::argmin{0},
                     input);
526
527
528
529
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
530
531
532
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
                     migraphx::op::argmin{1},
                     input);
533
534
535
536
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
537
538
539
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
                     migraphx::op::argmin{2},
                     input);
540
541
542
543
    }

    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
544
545
546
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
                     migraphx::op::argmin{3},
                     input);
Shucai Xiao's avatar
Shucai Xiao committed
547
548
549
550
551
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        throws_shape(migraphx::op::argmin{4}, input);
552
    }
Shucai Xiao's avatar
Shucai Xiao committed
553
}
554

555
556
TEST_CASE(test_squeeze)
{
    const migraphx::shape input{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
    // Only the listed unit axis (3) is removed; the unit axis 1 is kept.
    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}},
                 migraphx::op::squeeze{{3}},
                 input);
}
561

562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
TEST_CASE(test_squeeze_negative_axis)
{
    // Axis -2 of this rank-5 input refers to axis 3, the second unit dimension.
    const migraphx::shape input{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
    const migraphx::shape squeezed{migraphx::shape::float_type, {4, 1, 3, 3}};
    expect_shape(squeezed, migraphx::op::squeeze{{-2}}, input);
}

TEST_CASE(test_squeeze_wrong_axis)
{
    // Axis 0 has length 4, so squeezing it must fail.
    const migraphx::shape input{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
    throws_shape(migraphx::op::squeeze{{0}}, input);
}

TEST_CASE(test_squeeze_all)
{
    // Removing the only (unit) dimension leaves a scalar shape.
    const migraphx::shape input{migraphx::shape::float_type, {1}};
    expect_shape(migraphx::shape{migraphx::shape::float_type}, migraphx::op::squeeze{{0}}, input);
}

TEST_CASE(test_unsqueeze_scalar)
{
    // A {1} input with stride 0 unsqueezed at axis 0 yields a standard {1} shape.
    const migraphx::shape zero_stride{migraphx::shape::float_type, {1}, {0}};
    const migraphx::shape unit_stride{migraphx::shape::float_type, {1}, {1}};
    expect_shape(unit_stride, migraphx::op::unsqueeze{{0}}, zero_stride);
}

TEST_CASE(test_unsqueeze_scalar_tensor1)
{
    // An all-zero-stride rank-3 input is rejected by unsqueeze.
    const migraphx::shape broadcasted{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}};
    throws_shape(migraphx::op::unsqueeze{{-2}}, broadcasted);
}

TEST_CASE(test_unsqueeze_scalar_tensor2)
{
    // All-unit lens with zero strides are rejected as well.
    const migraphx::shape broadcasted{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}};
    throws_shape(migraphx::op::unsqueeze{{-2}}, broadcasted);
}

TEST_CASE(test_unsqueeze)
{
    // A unit dimension is inserted at position 2.
    const migraphx::shape input{migraphx::shape::float_type, {4, 3, 3}};
    const migraphx::shape expanded{migraphx::shape::float_type, {4, 3, 1, 3}};
    expect_shape(expanded, migraphx::op::unsqueeze{{2}}, input);
}

TEST_CASE(test_unsqueeze_negative_axis)
{
    // Axis -2 gives the same result as axis 2 for this rank-3 input (rank 4 output).
    const migraphx::shape input{migraphx::shape::float_type, {4, 3, 3}};
    const migraphx::shape expanded{migraphx::shape::float_type, {4, 3, 1, 3}};
    expect_shape(expanded, migraphx::op::unsqueeze{{-2}}, input);
}

Shucai Xiao's avatar
Shucai Xiao committed
615
616
template <class T>
void test_reduce_ops()
Shucai Xiao's avatar
Shucai Xiao committed
617
{
618
619
620
621
622
623
624
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
    }

    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
625
626
        expect_shape(
            migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
627
    }
Shucai Xiao's avatar
Shucai Xiao committed
628
629
630
631
632
633
634
635
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
    }
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
    }
636
637
638
639
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
    }
640
641
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
Shucai Xiao's avatar
Shucai Xiao committed
642
        throws_shape(T{{4}}, input);
643
644
645
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
646
647
// reduce_sum and reduce_mean are checked against the same shape rules.
TEST_CASE(reduce_sum)
{
    test_reduce_ops<migraphx::op::reduce_sum>();
}
TEST_CASE(reduce_mean)
{
    test_reduce_ops<migraphx::op::reduce_mean>();
}
648

649
650
// 2 inputs arguments
TEST_CASE(matmul)
651
652
{
    {
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
        migraphx::shape s_m1{migraphx::shape::float_type, {5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
Shucai Xiao's avatar
Shucai Xiao committed
673
        expect_shape(
674
            migraphx::shape{migraphx::shape::float_type, {1, 4}}, migraphx::op::dot{}, s_m1, s_m2);
675
676
677
    }

    {
678
679
680
681
682
683
684
685
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
686
687
688
        throws_shape(migraphx::op::dot{}, s_m1, s_m2);
    }

689
690
691
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
Shucai Xiao's avatar
Shucai Xiao committed
692
693
694
695
        expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
                     migraphx::op::dot{},
                     s_m1,
                     s_m2);
696
697
698
699
700
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
Shucai Xiao's avatar
Shucai Xiao committed
701
702
703
704
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
                     migraphx::op::dot{},
                     s_m1,
                     s_m2);
705
706
707
708
709
710
711
712
713
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        expect_shape(
            migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::op::dot{}, s_m1, s_m2);
    }

714
715
716
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
Shucai Xiao's avatar
Shucai Xiao committed
717
718
        expect_shape(
            migraphx::shape{migraphx::shape::float_type, {1, 1}}, migraphx::op::dot{}, s_m1, s_m2);
719
720
721
722
723
724
725
726
727
728
729
730
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}},
                     migraphx::op::dot{},
                     s_m1,
                     s_m2);
    }

    {
731
732
733
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2);
734
735
736
737
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2);
    }
}

// dot with three input arguments (A, B, C): the C operand must have exactly
// the shape of A x B; none of the broadcast-compatible C shapes below
// ({1}, {1, 1}, {8}, {4, 1}, {4}) are accepted.
TEST_CASE(gemm)
{
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {1}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {8}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
    }

    // Inner dimensions of A and B disagree (6 vs 5).
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
    }

    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {4}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
    }

    // C matches the {4, 8} output exactly.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
                     migraphx::op::dot{},
                     s_m1,
                     s_m2,
                     s_m3);
    }

    // Batched: C matches the {1, 4, 8} output exactly.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
                     migraphx::op::dot{},
                     s_m1,
                     s_m2,
                     s_m3);
    }

    // Batched, inner dimensions disagree (6 vs 5).
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
    }

    // Batched A/B with a 2-D C.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
    }

    // C constructed without dimensions.
    {
        migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
        migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
        migraphx::shape s_m3{migraphx::shape::float_type};
        throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
    }
}

832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
// quant_dot
TEST_CASE(quant_dot_2args)
{
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
                     migraphx::op::quant_dot{},
                     s_m1,
                     s_m2);
    }

    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
                     migraphx::op::quant_dot{1, 0},
                     s_m1,
                     s_m2);
    }

    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}};
Shucai Xiao's avatar
Shucai Xiao committed
856
        throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
857
858
859
860
861
    }

    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
Shucai Xiao's avatar
Shucai Xiao committed
862
        throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
    }
}

// quant_dot with a third (C) operand.
TEST_CASE(quant_dot_3args)
{
    // An int32 C with the output shape {2, 8} is accepted.
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
        migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
                     migraphx::op::quant_dot{},
                     s_m1,
                     s_m2,
                     s_m3);
    }

    // An int8 C (with attributes {1, 2}) is rejected; presumably C must be
    // int32 to match the output type -- confirm in quant_dot.
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
        migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
        throws_shape(migraphx::op::quant_dot{1, 2}, s_m1, s_m2, s_m3);
    }
}

887
888
889
890
891
892
893
894
TEST_CASE(rnn)
{
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
895
        float clip              = 0.0f;
896
897
898
899
900
901
902

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
903
904
905
906
907
908
909
910
911
912
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::op::rnn{
                hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
913
914
915
916
917
918
919
920
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
921
        float clip              = 0.0f;
922
923
924
925
926
927
928

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
929
930
931
932
933
934
935
936
937
938
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::op::rnn{
                hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
939
940
941
942
943
944
945
946
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
Shucai Xiao's avatar
Shucai Xiao committed
947
        float clip              = 0.0f;
948
949
950
951
952
953
954

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
955
956
957
958
959
960
961
962
963
964
965
        expect_shape(migraphx::shape{migraphx::shape::float_type,
                                     {seq_len, num_dirct, batch_size, hidden_size}},
                     migraphx::op::rnn{hidden_size,
                                       {migraphx::op::tanh{}},
                                       migraphx::op::rnn_direction::bidirectional,
                                       clip},
                     in_shape,
                     w_shape,
                     r_shape,
                     b_shape,
                     ih_shape);
966
967
968
969
970
971
972
973
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
974
        float clip              = 0.0f;
975
976
977
978
979
980
981

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
982
983
984
985
986
987
988
989
990
        throws_shape(migraphx::op::rnn{hidden_size + 1,
                                       {migraphx::op::tanh{}},
                                       migraphx::op::rnn_direction::forward,
                                       clip},
                     in_shape,
                     w_shape,
                     r_shape,
                     b_shape,
                     ih_shape);
991
992
993
994
995
996
997
998
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
Shucai Xiao's avatar
Shucai Xiao committed
999
        float clip              = 0.0f;
1000
1001
1002
1003
1004
1005
1006

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1007
1008
1009
1010
1011
1012
1013
1014
1015
        throws_shape(migraphx::op::rnn{hidden_size,
                                       {migraphx::op::tanh{}},
                                       migraphx::op::rnn_direction::bidirectional,
                                       clip},
                     in_shape,
                     w_shape,
                     r_shape,
                     b_shape,
                     ih_shape);
1016
1017
1018
1019
1020
1021
1022
1023
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
Shucai Xiao's avatar
Shucai Xiao committed
1024
        float clip              = 0.0f;
1025
1026
1027
1028
1029
1030
1031

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1032
1033
        throws_shape(
            migraphx::op::rnn{
1034
                hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
Shucai Xiao's avatar
Shucai Xiao committed
1035
1036
1037
1038
1039
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
1040
1041
1042
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
TEST_CASE(gru)
{
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1054
1055
1056
1057
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1058
1059
1060
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::op::gru{
                hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1082
1083
1084
1085
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1086
1087
1088
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::op::gru{
                hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1110
1111
1112
1113
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1114
1115
1116
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
        expect_shape(migraphx::shape{migraphx::shape::float_type,
                                     {seq_len, num_dirct, batch_size, hidden_size}},
                     migraphx::op::gru{hidden_size,
                                       {migraphx::op::tanh{}},
                                       migraphx::op::rnn_direction::bidirectional,
                                       clip},
                     in_shape,
                     w_shape,
                     r_shape,
                     b_shape,
                     ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1139
1140
1141
1142
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1143
1144
1145
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1146
1147
1148
1149
1150
1151
1152
1153
1154
        throws_shape(migraphx::op::gru{hidden_size + 1,
                                       {migraphx::op::tanh{}},
                                       migraphx::op::rnn_direction::forward,
                                       clip},
                     in_shape,
                     w_shape,
                     r_shape,
                     b_shape,
                     ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1166
1167
1168
1169
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1170
1171
1172
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
1173
1174
1175
1176
1177
1178
1179
1180
1181
        throws_shape(migraphx::op::gru{hidden_size,
                                       {migraphx::op::tanh{}},
                                       migraphx::op::rnn_direction::bidirectional,
                                       clip},
                     in_shape,
                     w_shape,
                     r_shape,
                     b_shape,
                     ih_shape);
Shucai Xiao's avatar
Shucai Xiao committed
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1193
1194
1195
1196
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
1197
1198
1199
1200
1201
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        throws_shape(
            migraphx::op::gru{
1202
                hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
Shucai Xiao's avatar
Shucai Xiao committed
1203
1204
1205
1206
1207
1208
1209
1210
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }
}

1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
TEST_CASE(lstm)
{
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};

        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::op::lstm{
                hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
            in_shape,
            w_shape,
            r_shape);
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::op::lstm{
                hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        expect_shape(migraphx::shape{migraphx::shape::float_type,
                                     {seq_len, num_dirct, batch_size, hidden_size}},
                     migraphx::op::lstm{hidden_size,
Shucai Xiao's avatar
Shucai Xiao committed
1284
1285
1286
                                        {migraphx::op::tanh{}},
                                        migraphx::op::rnn_direction::bidirectional,
                                        clip},
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
                     in_shape,
                     w_shape,
                     r_shape,
                     b_shape,
                     ih_shape);
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        throws_shape(migraphx::op::lstm{hidden_size + 1,
Shucai Xiao's avatar
Shucai Xiao committed
1311
1312
1313
                                        {migraphx::op::tanh{}},
                                        migraphx::op::rnn_direction::forward,
                                        clip},
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
                     in_shape,
                     w_shape,
                     r_shape,
                     b_shape,
                     ih_shape);
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        throws_shape(migraphx::op::lstm{hidden_size,
Shucai Xiao's avatar
Shucai Xiao committed
1338
1339
1340
                                        {migraphx::op::tanh{}},
                                        migraphx::op::rnn_direction::bidirectional,
                                        clip},
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
                     in_shape,
                     w_shape,
                     r_shape,
                     b_shape,
                     ih_shape);
    }

    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 2;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 3 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

        throws_shape(
            migraphx::op::lstm{
                hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }
}

Paul's avatar
Paul committed
1375
int main(int argc, const char* argv[]) { test::run(argc, argv); }