Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6a94c42a
Commit
6a94c42a
authored
Feb 20, 2019
by
Shucai Xiao
Browse files
clang format
parent
dcaf8fd3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
228 additions
and
310 deletions
+228
-310
test/cpu_rnn_ops_test.cpp
test/cpu_rnn_ops_test.cpp
+228
-310
No files found.
test/cpu_rnn_ops_test.cpp
View file @
6a94c42a
...
@@ -2774,313 +2774,236 @@ TEST_CASE(lstm_reverse)
...
@@ -2774,313 +2774,236 @@ TEST_CASE(lstm_reverse)
TEST_CASE(lstm_bidirectional)
TEST_CASE(lstm_bidirectional)
{
{
std
::
size_t
batch_size
=
3
;
std::size_t batch_size = 3;
std
::
size_t
seq_len
=
4
;
std::size_t seq_len = 4;
std
::
size_t
hidden_size
=
4
;
std::size_t hidden_size = 4;
std
::
size_t
input_size
=
3
;
std::size_t input_size = 3;
std
::
size_t
num_dirct
=
2
;
std::size_t num_dirct = 2;
std
::
vector
<
float
>
w_data
{
std::vector<float> w_data{
0.1236
,
-
0.3942
,
0.4149
,
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.0795
,
0.4934
,
-
0.2858
,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
0.2602
,
-
0.3098
,
0.0567
,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
0.3344
,
0.3607
,
-
0.0551
,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
0.4952
,
0.3799
,
0.0630
,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-
0.3532
,
0.0023
,
-
0.0592
,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.4267
,
0.2382
,
-
0.0784
,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-
0.0032
,
-
0.2476
,
-
0.0206
,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
-
0.4963
,
0.4837
,
0.0827
,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.0123
,
-
0.1203
,
-
0.0279
,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
-
0.0049
,
0.4721
,
-
0.3564
,
-
0.1286
,
0.4090
,
-
0.0504
,
0.0575
,
-
0.2138
,
0.1071
,
0.1976
,
-
0.0758
,
0.0139
,
-
0.0761
,
0.3991
,
-
0.2965
,
-
0.4845
,
-
0.1496
,
0.3285
,
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
0.1237
,
0.1229
,
-
0.0766
,
-
0.1144
,
-
0.1186
,
0.2922
,
0.2478
,
0.3159
,
-
0.0522
,
0.1685
,
-
0.4621
,
0.1728
,
0.0670
,
-
0.2458
,
-
0.3835
,
-
0.4589
,
-
0.3109
,
0.4908
,
-
0.0133
,
-
0.1858
,
-
0.0590
,
-
0.0347
,
-
0.2353
,
-
0.0671
,
-
0.3812
,
-
0.0004
,
-
0.1432
,
0.2406
,
0.1033
,
-
0.0265
,
-
0.3902
,
0.0755
,
0.3733
,
0.4383
,
-
0.3140
,
0.2537
,
-
0.1818
,
-
0.4127
,
0.3506
,
0.2562
,
0.2926
,
0.1620
,
-
0.4849
,
-
0.4861
,
0.4426
,
0.2106
,
-
0.0005
,
0.4418
,
-
0.2926
,
-
0.3100
,
0.1500
,
-
0.0362
,
-
0.3801
,
-
0.0065
,
-
0.0631
,
0.1277
,
0.2315
,
0.4087
,
-
0.3963
,
-
0.4161
,
-
0.2169
,
-
0.1344
,
0.3468
,
-
0.2260
,
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
0.0088
,
0.1183
,
0.1642
,
-
0.2631
,
-
0.1330
,
-
0.4008
,
0.3881
,
-
0.4407
,
-
0.2760
,
0.1274
,
-
0.0083
,
-
0.2885
,
0.3949
,
-
0.0182
,
0.4445
,
0.3477
,
0.2266
,
0.3423
,
-
0.0674
,
-
0.4067
,
0.0807
,
0.1109
,
-
0.2036
,
0.1782
,
-
0.2467
,
-
0.0730
,
-
0.4216
,
0.0316
,
-
0.3025
,
0.3637
,
-
0.3181
,
-
0.4655
,
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.1122
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.4100
,
-
2.2344
,
0.3685
,
0.4583
,
2.3794
,
1.0372
,
-
0.8887
,
0.7892
,
-
0.4012
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
std::vector<float> r_data{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
0.5671
,
0.0458
,
0.4514
,
-
0.8968
,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-
0.9201
,
0.1962
,
0.5771
,
-
0.5332
,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
1.5289
,
1.0986
,
0.6091
,
1.6462
,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.8720
,
0.5349
,
-
0.1962
,
-
1.7416
,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
std
::
vector
<
float
>
ic_data
{
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
1.1055
,
-
0.1212
,
-
0.9097
,
0.7831
,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
-
1.6991
,
-
1.9498
,
-
1.2567
,
-
0.4114
,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-
0.8323
,
0.3998
,
0.1831
,
0.5938
,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
2.7096
,
-
0.1790
,
0.0022
,
-
0.8040
,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
std::vector<float> bias_data{
std
::
vector
<
float
>
pph_data
{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274,
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
-0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637,
1.32594529
,
0.4336262
,
-
0.83699064
,
0.49162736
,
-0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
-0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
-0.4386, 0.4208, 0.0717, 0.3789};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
std::vector<float> input_data{
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458,
0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332,
1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349,
-0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212,
-0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114,
-0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790,
0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
-0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562,
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// concatenation of hidden states as program output
// concatenation of hidden states as program output
{
{
migraphx
::
program
p
;
migraphx::program p;
auto
seq
=
p
.
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto
ih
=
p
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto
ic
=
p
.
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto
w
=
p
.
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto
r
=
p
.
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto
bias
=
p
.
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto
pph
=
p
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto und = p.add_instruction(migraphx::op::undefined{});
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hidden_size
,
p.add_instruction(
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx::op::lstm{
migraphx
::
op
::
rnn_direction
::
bidirectional
,
hidden_size,
clip
,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
0
},
migraphx::op::rnn_direction::bidirectional,
seq
,
clip,
w
,
0},
r
,
seq,
bias
,
w,
und
,
r,
ih
,
bias,
ic
,
und,
pph
);
ih,
p
.
compile
(
migraphx
::
cpu
::
target
{});
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
auto hs_concat = p.eval({});
std
::
vector
<
float
>
output_data
;
std::vector<float> output_data;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
std::vector<float> output_data_gold{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
-0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
0.0459032
,
0.0414126
,
0.272303
,
0.0393149
,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
// last hidden state as program output
// last hidden state as program output
{
{
migraphx
::
program
p
;
migraphx::program p;
auto
seq
=
p
.
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto
ih
=
p
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto
ic
=
p
.
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto
w
=
p
.
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto
r
=
p
.
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto
bias
=
p
.
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto
pph
=
p
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto und = p.add_instruction(migraphx::op::undefined{});
auto
hs
=
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hidden_size
,
auto hs = p.add_instruction(
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx::op::lstm{
migraphx
::
op
::
rnn_direction
::
bidirectional
,
hidden_size,
clip
,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
0
},
migraphx::op::rnn_direction::bidirectional,
seq
,
clip,
w
,
0},
r
,
seq,
bias
,
w,
und
,
r,
ih
,
bias,
ic
,
und,
pph
);
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
auto hs_concat = p.eval({});
std
::
vector
<
float
>
output_data
;
std::vector<float> output_data;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
std::vector<float> output_data_gold{
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
// last cell output as program output
// last cell output as program output
{
{
migraphx
::
program
p
;
migraphx::program p;
auto
seq
=
p
.
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto
ih
=
p
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto
ic
=
p
.
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto
w
=
p
.
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto
r
=
p
.
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto
bias
=
p
.
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto
pph
=
p
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto und = p.add_instruction(migraphx::op::undefined{});
auto
hs
=
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hidden_size
,
auto hs = p.add_instruction(
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx::op::lstm{
migraphx
::
op
::
rnn_direction
::
bidirectional
,
hidden_size,
clip
,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
0
},
migraphx::op::rnn_direction::bidirectional,
seq
,
clip,
w
,
0},
r
,
seq,
bias
,
w,
und
,
r,
ih
,
bias,
ic
,
und,
pph
);
ih,
ic,
pph);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
auto hs_concat = p.eval({});
std
::
vector
<
float
>
output_data
;
std::vector<float> output_data;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
std::vector<float> output_data_gold{
-
0.077353
,
0.245616
,
0.361023
,
-
0.0443759
,
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.0685243
,
0.20465
,
0.277867
,
-
0.112934
,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
0.67312
,
0.120508
,
-
0.726968
,
0.113845
,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
-
0.889294
,
0.182463
,
0.186512
,
-
0.402334
,
1.48161
,
0.524116
,
0.347113
,
0.181813
,
-
0.434265
,
0.747833
,
0.416053
,
0.558713
};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
// 3 args, concatenation of hidden states as program output
// 3 args, concatenation of hidden states as program output
{
{
migraphx
::
program
p
;
migraphx::program p;
auto
seq
=
p
.
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto
w
=
p
.
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto
r
=
p
.
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hidden_size
,
p.add_instruction(
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx::op::lstm{
migraphx
::
op
::
rnn_direction
::
bidirectional
,
hidden_size,
clip
,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
0
},
migraphx::op::rnn_direction::bidirectional,
seq
,
clip,
w
,
0},
r
);
seq,
p
.
compile
(
migraphx
::
cpu
::
target
{});
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
auto hs_concat = p.eval({});
std
::
vector
<
float
>
output_data
;
std::vector<float> output_data;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
std::vector<float> output_data_gold{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328,
-
0.162851
,
-
0.102647
,
-
0.113827
,
-
0.142818
,
0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286,
0.0513685
,
0.0547876
,
0.0201981
,
-
0.00808453
,
0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681,
-
0.00520328
,
0.0945081
,
0.264123
,
0.410805
,
-0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636,
-
0.0786602
,
-
0.0613048
,
0.179592
,
-
0.071286
,
0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509,
0.074206
,
0.0124086
,
-
0.139544
,
0.108016
,
-0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329,
-
0.00973633
,
-
0.0552699
,
0.0252681
,
-
0.0562072
,
0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065,
-
0.123496
,
-
0.153616
,
-
0.032874
,
-
0.195349
,
-0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432,
0.0192675
,
-
0.108636
,
0.098927
,
-
0.140733
,
0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544,
0.162602
,
0.0143099
,
-
0.0455534
,
0.0151574
,
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-
0.102509
,
-
0.0372696
,
0.252296
,
-
0.144544
,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
0.00496085
,
0.0662588
,
-
0.048577
,
-
0.187329
,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
0.0855831
,
-
0.0171894
,
-
0.140202
,
0.0828391
,
-
0.1073
,
-
0.150145
,
0.015065
,
-
0.192699
,
-
0.112764
,
-
0.120496
,
0.155754
,
0.148256
,
0.208491
,
0.348432
,
0.0291103
,
0.230275
,
-
0.165194
,
-
0.0372928
,
0.273786
,
-
0.100877
,
-
0.0458544
,
-
0.0401315
,
0.0737483
,
-
0.064505
,
0.136898
,
0.00160891
,
-
0.184812
,
0.147774
,
-
0.021205
,
-
0.125423
,
0.0206439
,
-
0.187097
,
-
0.0051453
,
-
0.0767618
,
-
0.0735348
,
-
0.0826436
,
0.214159
,
0.262295
,
0.0247127
,
0.14472
};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
...
@@ -3088,38 +3011,33 @@ TEST_CASE(lstm_bidirectional)
...
@@ -3088,38 +3011,33 @@ TEST_CASE(lstm_bidirectional)
{
{
migraphx::program p;
migraphx::program p;
seq_len = 1;
seq_len = 1;
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std
::
vector
<
float
>
input_data1
{
std::vector<float> input_data1{
-
0.5516
,
0.2391
,
-
1.6951
,
-0.5516, 0.2391, -1.6951,
-0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
-
0.4313
,
-
0.9730
,
-
0.2005
,
auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
2.3930
,
-
0.5221
,
-
0.1331
}
;
auto w = p.add_literal(migraphx::literal{w_shape, w_data})
;
auto
seq
=
p
.
add_literal
(
migraphx
::
literal
{
in
_shape
1
,
input
_data
1
});
auto
r
= p.add_literal(migraphx::literal{
r
_shape,
r
_data});
auto
w
=
p
.
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
p.add_instruction(
auto
r
=
p
.
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
migraphx::op::lstm{
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hidden_size
,
hidden_size,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx::op::rnn_direction::bidirectional,
clip
,
clip,
0
},
0},
seq
,
seq,
w
,
w,
r
);
r);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
auto hs_concat = p.eval({});
std
::
vector
<
float
>
output_data
;
std::vector<float> output_data;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
std::vector<float> output_data_gold{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
-
0.104351
,
-
0.0471426
,
-
0.0905753
,
0.01506
,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
0.059797
,
0.104239
,
-
0.0266768
,
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment