Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fbc53b14
Commit
fbc53b14
authored
Feb 07, 2019
by
Shucai Xiao
Browse files
fix clang format issues.
parent
a40e58d3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
287 additions
and
0 deletions
+287
-0
test/cpu_rnn_ops_test.cpp
test/cpu_rnn_ops_test.cpp
+268
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+19
-0
No files found.
test/cpu_rnn_ops_test.cpp
View file @
fbc53b14
...
...
@@ -736,6 +736,64 @@ TEST_CASE(gru_forward)
EXPECT
(
migraphx
::
verify_range
(
hs_data
,
hs_data_gold
));
}
}
TEST_CASE
(
gru_forward_args
)
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
,
input_size
}};
std
::
vector
<
float
>
w_data
{
0.3485
,
-
0.0378
,
-
0.1782
,
0.1416
,
-
0.3096
,
-
0.2212
,
-
0.3883
,
0.1983
,
-
0.2418
,
0.1480
,
-
0.3255
,
0.1359
,
-
0.3551
,
-
0.3605
,
-
0.3482
,
-
0.1424
,
-
0.0495
,
-
0.1640
,
-
0.1979
,
-
0.2577
,
-
0.4097
,
-
0.1211
,
-
0.0412
,
0.1801
,
0.1721
,
-
0.4327
,
-
0.0498
,
0.2628
,
-
0.1573
,
-
0.1577
,
0.2759
,
-
0.2023
,
-
0.1185
,
-
0.2136
,
0.1294
,
-
0.2331
,
0.0701
,
0.4316
,
0.0480
,
0.0247
,
-
0.0166
,
-
0.2729
,
0.1712
,
-
0.3984
,
-
0.3905
};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
,
hidden_size
}};
std
::
vector
<
float
>
r_data
{
0.2848
,
-
0.2851
,
-
0.3466
,
-
0.1718
,
-
0.1492
,
-
0.0082
,
0.2452
,
-
0.0401
,
0.3399
,
0.2529
,
-
0.0953
,
-
0.0903
,
-
0.1518
,
-
0.1373
,
0.3848
,
-
0.0130
,
-
0.4339
,
0.0406
,
-
0.1926
,
-
0.1131
,
0.4285
,
-
0.0013
,
0.2243
,
0.2752
,
0.1776
,
-
0.1720
,
0.0822
,
-
0.0295
,
0.1062
,
-
0.2721
,
-
0.2736
,
-
0.1826
,
0.3541
,
-
0.4259
,
0.2188
,
0.0706
,
0.3650
,
0.3947
,
0.2522
,
0.2179
,
-
0.0744
,
0.2122
,
-
0.4346
,
0.2760
,
0.4076
,
0.1183
,
-
0.1500
,
-
0.1704
,
0.3090
,
-
0.0706
,
-
0.2442
,
0.3021
,
0.1680
,
0.0783
,
-
0.3754
,
-
0.3469
,
-
0.2972
,
-
0.0170
,
0.4143
,
0.3801
,
0.3852
,
-
0.1170
,
-
0.2937
,
0.2979
,
-
0.1357
,
0.4257
,
0.3884
,
-
0.2916
,
0.1071
,
0.0934
,
0.3645
,
-
0.4310
,
-
0.3480
,
0.0702
,
-
0.1558
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
6
*
hidden_size
}};
std
::
vector
<
float
>
bias_data
{
0.0560
,
0.0310
,
-
0.1669
,
-
0.0781
,
0.1793
,
-
0.1758
,
0.3173
,
-
0.1650
,
-
0.3732
,
0.2946
,
-
0.0912
,
0.3118
,
0.1391
,
0.2755
,
0.2695
,
-
0.1059
,
-
0.2357
,
0.3629
,
-
0.2534
,
-
0.0494
,
0.0556
,
0.0881
,
-
0.2592
,
-
0.2213
,
0.2310
,
-
0.4044
,
0.1801
,
0.1438
,
0.3108
,
-
0.3607
};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
std
::
vector
<
float
>
input
{
-
0.8432
,
-
0.9887
,
1.3041
,
-
2.6430
,
-
0.3306
,
-
0.8504
,
-
0.3933
,
0.5151
,
-
0.2951
,
0.0093
,
-
1.1948
,
-
0.1239
,
0.0373
,
1.3211
,
0.7854
,
-
0.4838
,
-
1.0536
,
-
0.2529
};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
ih_data
{
-
0.0468
,
0.5691
,
-
0.0882
,
0.8340
,
0.1483
,
-
0.3902
,
-
0.5348
,
0.4178
,
1.0175
,
0.9212
};
float
clip
=
0.0
f
;
// 3 args
{
...
...
@@ -833,6 +891,64 @@ TEST_CASE(gru_forward)
EXPECT
(
migraphx
::
verify_range
(
hs_data
,
hs_data_gold
));
}
}
TEST_CASE
(
gru_forward_actv_funcs
)
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
,
input_size
}};
std
::
vector
<
float
>
w_data
{
0.3485
,
-
0.0378
,
-
0.1782
,
0.1416
,
-
0.3096
,
-
0.2212
,
-
0.3883
,
0.1983
,
-
0.2418
,
0.1480
,
-
0.3255
,
0.1359
,
-
0.3551
,
-
0.3605
,
-
0.3482
,
-
0.1424
,
-
0.0495
,
-
0.1640
,
-
0.1979
,
-
0.2577
,
-
0.4097
,
-
0.1211
,
-
0.0412
,
0.1801
,
0.1721
,
-
0.4327
,
-
0.0498
,
0.2628
,
-
0.1573
,
-
0.1577
,
0.2759
,
-
0.2023
,
-
0.1185
,
-
0.2136
,
0.1294
,
-
0.2331
,
0.0701
,
0.4316
,
0.0480
,
0.0247
,
-
0.0166
,
-
0.2729
,
0.1712
,
-
0.3984
,
-
0.3905
};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
,
hidden_size
}};
std
::
vector
<
float
>
r_data
{
0.2848
,
-
0.2851
,
-
0.3466
,
-
0.1718
,
-
0.1492
,
-
0.0082
,
0.2452
,
-
0.0401
,
0.3399
,
0.2529
,
-
0.0953
,
-
0.0903
,
-
0.1518
,
-
0.1373
,
0.3848
,
-
0.0130
,
-
0.4339
,
0.0406
,
-
0.1926
,
-
0.1131
,
0.4285
,
-
0.0013
,
0.2243
,
0.2752
,
0.1776
,
-
0.1720
,
0.0822
,
-
0.0295
,
0.1062
,
-
0.2721
,
-
0.2736
,
-
0.1826
,
0.3541
,
-
0.4259
,
0.2188
,
0.0706
,
0.3650
,
0.3947
,
0.2522
,
0.2179
,
-
0.0744
,
0.2122
,
-
0.4346
,
0.2760
,
0.4076
,
0.1183
,
-
0.1500
,
-
0.1704
,
0.3090
,
-
0.0706
,
-
0.2442
,
0.3021
,
0.1680
,
0.0783
,
-
0.3754
,
-
0.3469
,
-
0.2972
,
-
0.0170
,
0.4143
,
0.3801
,
0.3852
,
-
0.1170
,
-
0.2937
,
0.2979
,
-
0.1357
,
0.4257
,
0.3884
,
-
0.2916
,
0.1071
,
0.0934
,
0.3645
,
-
0.4310
,
-
0.3480
,
0.0702
,
-
0.1558
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
6
*
hidden_size
}};
std
::
vector
<
float
>
bias_data
{
0.0560
,
0.0310
,
-
0.1669
,
-
0.0781
,
0.1793
,
-
0.1758
,
0.3173
,
-
0.1650
,
-
0.3732
,
0.2946
,
-
0.0912
,
0.3118
,
0.1391
,
0.2755
,
0.2695
,
-
0.1059
,
-
0.2357
,
0.3629
,
-
0.2534
,
-
0.0494
,
0.0556
,
0.0881
,
-
0.2592
,
-
0.2213
,
0.2310
,
-
0.4044
,
0.1801
,
0.1438
,
0.3108
,
-
0.3607
};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
std
::
vector
<
float
>
input
{
-
0.8432
,
-
0.9887
,
1.3041
,
-
2.6430
,
-
0.3306
,
-
0.8504
,
-
0.3933
,
0.5151
,
-
0.2951
,
0.0093
,
-
1.1948
,
-
0.1239
,
0.0373
,
1.3211
,
0.7854
,
-
0.4838
,
-
1.0536
,
-
0.2529
};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
ih_data
{
-
0.0468
,
0.5691
,
-
0.0882
,
0.8340
,
0.1483
,
-
0.3902
,
-
0.5348
,
0.4178
,
1.0175
,
0.9212
};
float
clip
=
0.0
f
;
// no activation function specified, so default is used.
{
...
...
@@ -1422,6 +1538,82 @@ TEST_CASE(gru_bidirectional)
EXPECT
(
migraphx
::
verify_range
(
hs_data
,
hs_data_gold
));
}
}
TEST_CASE
(
gru_bidirectional_args
)
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
2
;
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
,
input_size
}};
std
::
vector
<
float
>
w_data
{
0.3809
,
0.4283
,
0.2294
,
-
0.1018
,
-
0.1226
,
-
0.0037
,
0.2449
,
-
0.2712
,
-
0.1418
,
0.1363
,
-
0.3453
,
-
0.0693
,
-
0.2281
,
0.2699
,
-
0.2024
,
-
0.3085
,
-
0.3338
,
0.4109
,
0.2605
,
-
0.1019
,
-
0.2813
,
0.3323
,
-
0.1590
,
0.0788
,
-
0.3535
,
0.0397
,
0.2732
,
0.2906
,
0.0519
,
0.3617
,
-
0.2664
,
0.1441
,
0.0464
,
-
0.1057
,
0.2204
,
-
0.3294
,
0.3670
,
0.1411
,
0.3852
,
0.3572
,
0.3918
,
0.0483
,
-
0.3906
,
-
0.2841
,
-
0.2778
,
-
0.4272
,
0.2335
,
-
0.1811
,
-
0.3885
,
-
0.1279
,
0.1000
,
0.0206
,
-
0.3284
,
-
0.0353
,
0.1197
,
0.1190
,
0.3862
,
0.0965
,
-
0.0492
,
0.2657
,
-
0.1430
,
0.0597
,
0.1408
,
-
0.0315
,
0.1248
,
0.0751
,
0.3838
,
0.3020
,
0.0515
,
0.2375
,
-
0.4255
,
0.1714
,
-
0.0432
,
0.3447
,
-
0.2441
,
-
0.3989
,
-
0.3428
,
-
0.4204
,
-
0.4080
,
-
0.2683
,
-
0.0996
,
-
0.1685
,
-
0.0532
,
-
0.1258
,
0.1663
,
-
0.3526
,
-
0.3915
,
-
0.1721
,
0.1292
,
-
0.2279
};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
,
hidden_size
}};
std
::
vector
<
float
>
r_data
{
-
0.2683
,
0.0699
,
-
0.4021
,
-
0.1379
,
0.0042
,
-
0.2447
,
0.4006
,
0.0270
,
-
0.0446
,
0.1063
,
0.1381
,
0.1310
,
-
0.3596
,
0.3869
,
0.3929
,
0.2750
,
0.0890
,
0.3069
,
-
0.1691
,
-
0.2194
,
-
0.1066
,
0.3187
,
-
0.4369
,
-
0.0603
,
-
0.0834
,
-
0.1182
,
-
0.2047
,
0.3253
,
-
0.2931
,
0.2082
,
0.0424
,
0.1111
,
-
0.2773
,
-
0.0279
,
-
0.0869
,
0.1413
,
-
0.4227
,
-
0.3672
,
0.4137
,
0.0609
,
0.4223
,
-
0.4032
,
0.2945
,
0.3600
,
0.3345
,
-
0.3880
,
-
0.0192
,
-
0.0090
,
-
0.2648
,
0.4339
,
-
0.0155
,
0.4437
,
-
0.1766
,
0.1957
,
0.2475
,
0.3773
,
-
0.2710
,
0.3289
,
-
0.2077
,
-
0.2534
,
-
0.0832
,
-
0.1632
,
0.0728
,
0.2520
,
0.4153
,
0.1659
,
-
0.4342
,
0.0541
,
0.1812
,
-
0.2305
,
0.4440
,
0.0946
,
0.0410
,
-
0.4381
,
-
0.3161
,
0.3906
,
-
0.3958
,
-
0.4238
,
0.1975
,
0.3440
,
0.1437
,
-
0.0568
,
0.1492
,
-
0.4248
,
-
0.3304
,
0.2786
,
-
0.1328
,
-
0.3740
,
-
0.3566
,
0.3074
,
0.0924
,
0.2684
,
-
0.1527
,
0.1826
,
0.2424
,
0.2002
,
0.3479
,
-
0.1089
,
0.3472
,
-
0.3677
,
-
0.4231
,
-
0.0798
,
-
0.3709
,
0.3924
,
0.2774
,
-
0.3690
,
-
0.0233
,
0.2845
,
0.1969
,
0.1618
,
-
0.3742
,
-
0.3619
,
0.2925
,
-
0.1838
,
-
0.1495
,
-
0.3747
,
0.0341
,
-
0.4243
,
-
0.0732
,
-
0.3997
,
0.2139
,
0.2425
,
0.4171
,
-
0.3358
,
0.3534
,
0.0938
,
-
0.0582
,
-
0.2681
,
-
0.4293
,
0.1027
,
0.4101
,
0.2641
,
-
0.4110
,
-
0.1681
,
0.3582
,
-
0.2089
,
0.0852
,
0.0963
,
0.3866
,
0.1955
,
-
0.2174
,
0.1996
,
-
0.2252
,
0.1748
,
0.1833
,
-
0.3155
,
0.2567
,
-
0.4387
,
0.3402
,
0.0599
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
6
*
hidden_size
}};
std
::
vector
<
float
>
bias_data
{
-
0.1582
,
-
0.0826
,
0.4008
,
0.0118
,
0.2511
,
0.1900
,
-
0.2838
,
0.2549
,
-
0.2484
,
0.2363
,
-
0.4083
,
-
0.0295
,
-
0.1161
,
0.1211
,
0.2509
,
-
0.1414
,
-
0.2628
,
-
0.2992
,
0.1517
,
0.1817
,
-
0.2783
,
0.3183
,
-
0.1629
,
-
0.3108
,
-
0.3418
,
0.0411
,
0.2203
,
0.2187
,
-
0.2990
,
-
0.0416
,
0.0209
,
-
0.1024
,
0.4443
,
-
0.4420
,
-
0.0330
,
-
0.3591
,
-
0.2990
,
0.2167
,
0.1395
,
0.2317
,
0.1318
,
0.1909
,
-
0.3615
,
0.1953
,
-
0.2582
,
-
0.2217
,
0.3723
,
0.1458
,
0.2630
,
-
0.0377
,
0.1754
,
0.0800
,
-
0.3964
,
-
0.3247
,
0.4219
,
-
0.0900
,
0.3553
,
0.2614
,
-
0.1298
,
-
0.1124
};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
std
::
vector
<
float
>
input
{
-
0.8432
,
-
0.9887
,
1.3041
,
-
2.6430
,
-
0.3306
,
-
0.8504
,
-
0.3933
,
0.5151
,
-
0.2951
,
0.0093
,
-
1.1948
,
-
0.1239
,
0.0373
,
1.3211
,
0.7854
,
-
0.4838
,
-
1.0536
,
-
0.2529
};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
ih_data
{
-
0.0468
,
0.5691
,
-
0.0882
,
0.8340
,
0.1483
,
-
0.3902
,
-
0.5348
,
0.4178
,
1.0175
,
0.9212
,
-
0.0468
,
0.5691
,
-
0.0882
,
0.8340
,
0.1483
,
-
0.3902
,
-
0.5348
,
0.4178
,
1.0175
,
0.9212
};
float
clip
=
0.0
f
;
// 3 args
{
...
...
@@ -1530,6 +1722,82 @@ TEST_CASE(gru_bidirectional)
-
0.0339407
,
0.413089
,
0.721238
,
0.431879
};
EXPECT
(
migraphx
::
verify_range
(
hs_data
,
hs_data_gold
));
}
}
TEST_CASE
(
gru_bidirectional_actv_funcs
)
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
2
;
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
,
input_size
}};
std
::
vector
<
float
>
w_data
{
0.3809
,
0.4283
,
0.2294
,
-
0.1018
,
-
0.1226
,
-
0.0037
,
0.2449
,
-
0.2712
,
-
0.1418
,
0.1363
,
-
0.3453
,
-
0.0693
,
-
0.2281
,
0.2699
,
-
0.2024
,
-
0.3085
,
-
0.3338
,
0.4109
,
0.2605
,
-
0.1019
,
-
0.2813
,
0.3323
,
-
0.1590
,
0.0788
,
-
0.3535
,
0.0397
,
0.2732
,
0.2906
,
0.0519
,
0.3617
,
-
0.2664
,
0.1441
,
0.0464
,
-
0.1057
,
0.2204
,
-
0.3294
,
0.3670
,
0.1411
,
0.3852
,
0.3572
,
0.3918
,
0.0483
,
-
0.3906
,
-
0.2841
,
-
0.2778
,
-
0.4272
,
0.2335
,
-
0.1811
,
-
0.3885
,
-
0.1279
,
0.1000
,
0.0206
,
-
0.3284
,
-
0.0353
,
0.1197
,
0.1190
,
0.3862
,
0.0965
,
-
0.0492
,
0.2657
,
-
0.1430
,
0.0597
,
0.1408
,
-
0.0315
,
0.1248
,
0.0751
,
0.3838
,
0.3020
,
0.0515
,
0.2375
,
-
0.4255
,
0.1714
,
-
0.0432
,
0.3447
,
-
0.2441
,
-
0.3989
,
-
0.3428
,
-
0.4204
,
-
0.4080
,
-
0.2683
,
-
0.0996
,
-
0.1685
,
-
0.0532
,
-
0.1258
,
0.1663
,
-
0.3526
,
-
0.3915
,
-
0.1721
,
0.1292
,
-
0.2279
};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
,
hidden_size
}};
std
::
vector
<
float
>
r_data
{
-
0.2683
,
0.0699
,
-
0.4021
,
-
0.1379
,
0.0042
,
-
0.2447
,
0.4006
,
0.0270
,
-
0.0446
,
0.1063
,
0.1381
,
0.1310
,
-
0.3596
,
0.3869
,
0.3929
,
0.2750
,
0.0890
,
0.3069
,
-
0.1691
,
-
0.2194
,
-
0.1066
,
0.3187
,
-
0.4369
,
-
0.0603
,
-
0.0834
,
-
0.1182
,
-
0.2047
,
0.3253
,
-
0.2931
,
0.2082
,
0.0424
,
0.1111
,
-
0.2773
,
-
0.0279
,
-
0.0869
,
0.1413
,
-
0.4227
,
-
0.3672
,
0.4137
,
0.0609
,
0.4223
,
-
0.4032
,
0.2945
,
0.3600
,
0.3345
,
-
0.3880
,
-
0.0192
,
-
0.0090
,
-
0.2648
,
0.4339
,
-
0.0155
,
0.4437
,
-
0.1766
,
0.1957
,
0.2475
,
0.3773
,
-
0.2710
,
0.3289
,
-
0.2077
,
-
0.2534
,
-
0.0832
,
-
0.1632
,
0.0728
,
0.2520
,
0.4153
,
0.1659
,
-
0.4342
,
0.0541
,
0.1812
,
-
0.2305
,
0.4440
,
0.0946
,
0.0410
,
-
0.4381
,
-
0.3161
,
0.3906
,
-
0.3958
,
-
0.4238
,
0.1975
,
0.3440
,
0.1437
,
-
0.0568
,
0.1492
,
-
0.4248
,
-
0.3304
,
0.2786
,
-
0.1328
,
-
0.3740
,
-
0.3566
,
0.3074
,
0.0924
,
0.2684
,
-
0.1527
,
0.1826
,
0.2424
,
0.2002
,
0.3479
,
-
0.1089
,
0.3472
,
-
0.3677
,
-
0.4231
,
-
0.0798
,
-
0.3709
,
0.3924
,
0.2774
,
-
0.3690
,
-
0.0233
,
0.2845
,
0.1969
,
0.1618
,
-
0.3742
,
-
0.3619
,
0.2925
,
-
0.1838
,
-
0.1495
,
-
0.3747
,
0.0341
,
-
0.4243
,
-
0.0732
,
-
0.3997
,
0.2139
,
0.2425
,
0.4171
,
-
0.3358
,
0.3534
,
0.0938
,
-
0.0582
,
-
0.2681
,
-
0.4293
,
0.1027
,
0.4101
,
0.2641
,
-
0.4110
,
-
0.1681
,
0.3582
,
-
0.2089
,
0.0852
,
0.0963
,
0.3866
,
0.1955
,
-
0.2174
,
0.1996
,
-
0.2252
,
0.1748
,
0.1833
,
-
0.3155
,
0.2567
,
-
0.4387
,
0.3402
,
0.0599
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
6
*
hidden_size
}};
std
::
vector
<
float
>
bias_data
{
-
0.1582
,
-
0.0826
,
0.4008
,
0.0118
,
0.2511
,
0.1900
,
-
0.2838
,
0.2549
,
-
0.2484
,
0.2363
,
-
0.4083
,
-
0.0295
,
-
0.1161
,
0.1211
,
0.2509
,
-
0.1414
,
-
0.2628
,
-
0.2992
,
0.1517
,
0.1817
,
-
0.2783
,
0.3183
,
-
0.1629
,
-
0.3108
,
-
0.3418
,
0.0411
,
0.2203
,
0.2187
,
-
0.2990
,
-
0.0416
,
0.0209
,
-
0.1024
,
0.4443
,
-
0.4420
,
-
0.0330
,
-
0.3591
,
-
0.2990
,
0.2167
,
0.1395
,
0.2317
,
0.1318
,
0.1909
,
-
0.3615
,
0.1953
,
-
0.2582
,
-
0.2217
,
0.3723
,
0.1458
,
0.2630
,
-
0.0377
,
0.1754
,
0.0800
,
-
0.3964
,
-
0.3247
,
0.4219
,
-
0.0900
,
0.3553
,
0.2614
,
-
0.1298
,
-
0.1124
};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
std
::
vector
<
float
>
input
{
-
0.8432
,
-
0.9887
,
1.3041
,
-
2.6430
,
-
0.3306
,
-
0.8504
,
-
0.3933
,
0.5151
,
-
0.2951
,
0.0093
,
-
1.1948
,
-
0.1239
,
0.0373
,
1.3211
,
0.7854
,
-
0.4838
,
-
1.0536
,
-
0.2529
};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
ih_data
{
-
0.0468
,
0.5691
,
-
0.0882
,
0.8340
,
0.1483
,
-
0.3902
,
-
0.5348
,
0.4178
,
1.0175
,
0.9212
,
-
0.0468
,
0.5691
,
-
0.0882
,
0.8340
,
0.1483
,
-
0.3902
,
-
0.5348
,
0.4178
,
1.0175
,
0.9212
};
float
clip
=
0.0
f
;
// no activation function specified, so default is used.
{
...
...
test/onnx/onnx_test.cpp
View file @
fbc53b14
...
...
@@ -741,6 +741,16 @@ TEST_CASE(gru_test)
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
gru_test_args
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
2
;
// num directions
float
clip
=
0.0
f
;
// 3 arguments
{
...
...
@@ -836,7 +846,16 @@ TEST_CASE(gru_test)
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
gru_test_actv_funcs
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
2
;
// num directions
float
clip
=
0.0
f
;
// bidirection, 0 actv function
{
nd
=
2
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment