Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ceb68567
"...git@developer.sourcefind.cn:chenzk/alphafold2_jax.git" did not exist on "d9e5e1d9c65ff339328c0fb7078057593eecd043"
Commit
ceb68567
authored
Feb 20, 2019
by
Shucai Xiao
Browse files
add actv function test for the lstm operator.
parent
6a94c42a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
131 additions
and
0 deletions
+131
-0
test/cpu_rnn_ops_test.cpp
test/cpu_rnn_ops_test.cpp
+131
-0
No files found.
test/cpu_rnn_ops_test.cpp
View file @
ceb68567
...
@@ -3040,4 +3040,135 @@ TEST_CASE(lstm_bidirectional)
...
@@ -3040,4 +3040,135 @@ TEST_CASE(lstm_bidirectional)
}
}
}
}
TEST_CASE
(
lstm_bidirectional_actv_func
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
2
;
std
::
vector
<
float
>
w_data
{
0.1236
,
-
0.3942
,
0.4149
,
0.0795
,
0.4934
,
-
0.2858
,
0.2602
,
-
0.3098
,
0.0567
,
0.3344
,
0.3607
,
-
0.0551
,
0.4952
,
0.3799
,
0.0630
,
-
0.3532
,
0.0023
,
-
0.0592
,
0.4267
,
0.2382
,
-
0.0784
,
-
0.0032
,
-
0.2476
,
-
0.0206
,
-
0.4963
,
0.4837
,
0.0827
,
0.0123
,
-
0.1203
,
-
0.0279
,
-
0.0049
,
0.4721
,
-
0.3564
,
-
0.1286
,
0.4090
,
-
0.0504
,
0.0575
,
-
0.2138
,
0.1071
,
0.1976
,
-
0.0758
,
0.0139
,
-
0.0761
,
0.3991
,
-
0.2965
,
-
0.4845
,
-
0.1496
,
0.3285
,
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
0.1237
,
0.1229
,
-
0.0766
,
-
0.1144
,
-
0.1186
,
0.2922
,
0.2478
,
0.3159
,
-
0.0522
,
0.1685
,
-
0.4621
,
0.1728
,
0.0670
,
-
0.2458
,
-
0.3835
,
-
0.4589
,
-
0.3109
,
0.4908
,
-
0.0133
,
-
0.1858
,
-
0.0590
,
-
0.0347
,
-
0.2353
,
-
0.0671
,
-
0.3812
,
-
0.0004
,
-
0.1432
,
0.2406
,
0.1033
,
-
0.0265
,
-
0.3902
,
0.0755
,
0.3733
,
0.4383
,
-
0.3140
,
0.2537
,
-
0.1818
,
-
0.4127
,
0.3506
,
0.2562
,
0.2926
,
0.1620
,
-
0.4849
,
-
0.4861
,
0.4426
,
0.2106
,
-
0.0005
,
0.4418
,
-
0.2926
,
-
0.3100
,
0.1500
,
-
0.0362
,
-
0.3801
,
-
0.0065
,
-
0.0631
,
0.1277
,
0.2315
,
0.4087
,
-
0.3963
,
-
0.4161
,
-
0.2169
,
-
0.1344
,
0.3468
,
-
0.2260
,
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.1122
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.4100
,
-
2.2344
,
0.3685
,
0.4583
,
2.3794
,
1.0372
,
-
0.8887
,
0.7892
,
-
0.4012
,
-
0.2818
,
-
2.3374
,
1.5310
};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
// 3 args, 0 actv func
{
migraphx
::
program
p
;
auto
seq
=
p
.
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
p
.
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
p
.
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hidden_size
,
{},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
0
},
seq
,
w
,
r
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
hs_concat
=
p
.
eval
({});
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-
0.162851
,
-
0.102647
,
-
0.113827
,
-
0.142818
,
0.0513685
,
0.0547876
,
0.0201981
,
-
0.00808453
,
-
0.00520328
,
0.0945081
,
0.264123
,
0.410805
,
-
0.0786602
,
-
0.0613048
,
0.179592
,
-
0.071286
,
0.074206
,
0.0124086
,
-
0.139544
,
0.108016
,
-
0.00973633
,
-
0.0552699
,
0.0252681
,
-
0.0562072
,
-
0.123496
,
-
0.153616
,
-
0.032874
,
-
0.195349
,
0.0192675
,
-
0.108636
,
0.098927
,
-
0.140733
,
0.162602
,
0.0143099
,
-
0.0455534
,
0.0151574
,
-
0.102509
,
-
0.0372696
,
0.252296
,
-
0.144544
,
0.00496085
,
0.0662588
,
-
0.048577
,
-
0.187329
,
0.0855831
,
-
0.0171894
,
-
0.140202
,
0.0828391
,
-
0.1073
,
-
0.150145
,
0.015065
,
-
0.192699
,
-
0.112764
,
-
0.120496
,
0.155754
,
0.148256
,
0.208491
,
0.348432
,
0.0291103
,
0.230275
,
-
0.165194
,
-
0.0372928
,
0.273786
,
-
0.100877
,
-
0.0458544
,
-
0.0401315
,
0.0737483
,
-
0.064505
,
0.136898
,
0.00160891
,
-
0.184812
,
0.147774
,
-
0.021205
,
-
0.125423
,
0.0206439
,
-
0.187097
,
-
0.0051453
,
-
0.0767618
,
-
0.0735348
,
-
0.0826436
,
0.214159
,
0.262295
,
0.0247127
,
0.14472
};
EXPECT
(
migraphx
::
verify_range
(
output_data
,
output_data_gold
));
}
// 3 args, 1 actv func
{
migraphx
::
program
p
;
auto
seq
=
p
.
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
p
.
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
p
.
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hidden_size
,
{
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
0
},
seq
,
w
,
r
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
hs_concat
=
p
.
eval
({});
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
0.227861
,
0.328562
,
0.277867
,
0.272945
,
0.204389
,
0.296123
,
0.223834
,
0.311113
,
0.424666
,
0.173974
,
0.40628
,
0.286631
,
0.246078
,
0.199709
,
0.303753
,
0.301178
,
0.264634
,
0.304661
,
0.349371
,
0.288934
,
0.405483
,
0.445586
,
0.515814
,
0.473186
,
0.339438
,
0.29655
,
0.331832
,
0.242338
,
0.409384
,
0.236272
,
0.306045
,
0.26269
,
0.261246
,
0.334357
,
0.23622
,
0.245288
,
0.301937
,
0.264893
,
0.254353
,
0.269231
,
0.359258
,
0.400097
,
0.288884
,
0.247329
,
0.276519
,
0.264249
,
0.1769
,
0.23213
,
0.374123
,
0.283167
,
0.377129
,
0.245726
,
0.444712
,
0.203168
,
0.411446
,
0.269965
,
0.172792
,
0.296224
,
0.17319
,
0.352547
,
0.310306
,
0.262902
,
0.276964
,
0.295002
,
0.373802
,
0.366785
,
0.419791
,
0.393216
,
0.262827
,
0.371441
,
0.369022
,
0.298262
,
0.450186
,
0.263538
,
0.402895
,
0.216177
,
0.267257
,
0.342535
,
0.257797
,
0.268563
,
0.193043
,
0.275645
,
0.167678
,
0.350889
,
0.334143
,
0.309444
,
0.174822
,
0.251634
,
0.244564
,
0.214386
,
0.185994
,
0.226699
,
0.28445
,
0.376092
,
0.338326
,
0.259502
};
EXPECT
(
migraphx
::
verify_range
(
output_data
,
output_data_gold
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment