Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fed23ec7
Commit
fed23ec7
authored
Nov 21, 2023
by
Artur Wojcik
Browse files
Merge branch 'develop' into uif2-initial
parents
9d933920
225873aa
Changes
55
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
2865 additions
and
406 deletions
+2865
-406
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+75
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-1
test/quantization.cpp
test/quantization.cpp
+70
-70
test/ref/rnn_ops.cpp
test/ref/rnn_ops.cpp
+1727
-234
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+15
-19
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+363
-70
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+5
-10
test/verify/test_lstm_bidirct_3args_layout.cpp
test/verify/test_lstm_bidirct_3args_layout.cpp
+77
-0
test/verify/test_lstm_bidirct_last_layout.cpp
test/verify/test_lstm_bidirct_last_layout.cpp
+95
-0
test/verify/test_lstm_forward_hs_layout.cpp
test/verify/test_lstm_forward_hs_layout.cpp
+95
-0
test/verify/test_lstm_forward_last_layout.cpp
test/verify/test_lstm_forward_last_layout.cpp
+97
-0
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
+78
-0
test/verify/test_lstm_reverse_3args_layout.cpp
test/verify/test_lstm_reverse_3args_layout.cpp
+78
-0
test/verify/test_lstm_three_outputs_layout.cpp
test/verify/test_lstm_three_outputs_layout.cpp
+85
-0
tools/api/migraphx.h
tools/api/migraphx.h
+5
-2
No files found.
test/onnx/verify_onnx.cpp
View file @
fed23ec7
...
...
@@ -1895,6 +1895,81 @@ TEST_CASE(qlinearmatmul_3D_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
qlinearmul_test
)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"qlinearmul_test.onnx"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
shape
a
{
migraphx
::
shape
::
uint8_type
,
{
64
}};
std
::
vector
<
uint8_t
>
data_a
=
{
0
,
2
,
4
,
6
,
8
,
10
,
12
,
14
,
16
,
18
,
20
,
22
,
24
,
26
,
28
,
30
,
32
,
34
,
36
,
38
,
40
,
42
,
44
,
46
,
48
,
50
,
52
,
54
,
56
,
58
,
60
,
62
,
64
,
66
,
68
,
70
,
72
,
74
,
76
,
78
,
80
,
82
,
84
,
86
,
88
,
90
,
92
,
94
,
96
,
98
,
100
,
102
,
104
,
106
,
108
,
110
,
112
,
114
,
116
,
118
,
120
,
122
,
124
,
126
};
migraphx
::
shape
b
{
migraphx
::
shape
::
uint8_type
,
{
64
}};
std
::
vector
<
uint8_t
>
data_b
=
{
128
,
126
,
124
,
122
,
120
,
118
,
116
,
114
,
112
,
110
,
108
,
106
,
104
,
102
,
100
,
98
,
96
,
94
,
92
,
90
,
88
,
86
,
84
,
82
,
80
,
78
,
76
,
74
,
72
,
70
,
68
,
66
,
64
,
62
,
60
,
58
,
56
,
54
,
52
,
50
,
48
,
46
,
44
,
42
,
40
,
38
,
36
,
34
,
32
,
30
,
28
,
26
,
24
,
22
,
20
,
18
,
16
,
14
,
12
,
10
,
8
,
6
,
4
,
2
};
migraphx
::
parameter_map
pp
;
pp
[
"A"
]
=
migraphx
::
argument
(
a
,
data_a
.
data
());
pp
[
"B"
]
=
migraphx
::
argument
(
b
,
data_b
.
data
());
auto
result
=
p
.
eval
(
pp
).
back
();
std
::
vector
<
uint8_t
>
result_vector
;
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
uint8_t
>
gold
=
{
100
,
111
,
122
,
132
,
142
,
151
,
160
,
169
,
177
,
185
,
192
,
199
,
206
,
212
,
218
,
223
,
228
,
233
,
237
,
241
,
244
,
247
,
250
,
252
,
254
,
255
,
255
,
255
,
255
,
255
,
255
,
255
,
254
,
252
,
250
,
247
,
244
,
241
,
237
,
233
,
228
,
223
,
218
,
212
,
206
,
199
,
192
,
185
,
177
,
169
,
160
,
151
,
142
,
132
,
122
,
111
,
100
,
89
,
77
,
65
,
52
,
39
,
26
,
12
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
qlinearmul_bcast_test
)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"qlinearmul_bcast_test.onnx"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
shape
a
{
migraphx
::
shape
::
int8_type
,
{
64
}};
std
::
vector
<
int8_t
>
data_a
=
{
-
64
,
-
62
,
-
60
,
-
58
,
-
56
,
-
54
,
-
52
,
-
50
,
-
48
,
-
46
,
-
44
,
-
42
,
-
40
,
-
38
,
-
36
,
-
34
,
-
32
,
-
30
,
-
28
,
-
26
,
-
24
,
-
22
,
-
20
,
-
18
,
-
16
,
-
14
,
-
12
,
-
10
,
-
8
,
-
6
,
-
4
,
-
2
,
0
,
2
,
4
,
6
,
8
,
10
,
12
,
14
,
16
,
18
,
20
,
22
,
24
,
26
,
28
,
30
,
32
,
34
,
36
,
38
,
40
,
42
,
44
,
46
,
48
,
50
,
52
,
54
,
56
,
58
,
60
,
62
};
migraphx
::
shape
b
{
migraphx
::
shape
::
int8_type
,
{
1
,
1
,
64
}};
std
::
vector
<
int8_t
>
data_b
=
{
96
,
94
,
92
,
90
,
88
,
86
,
84
,
82
,
80
,
78
,
76
,
74
,
72
,
70
,
68
,
66
,
64
,
62
,
60
,
58
,
56
,
54
,
52
,
50
,
48
,
46
,
44
,
42
,
40
,
38
,
36
,
34
,
32
,
30
,
28
,
26
,
24
,
22
,
20
,
18
,
16
,
14
,
12
,
10
,
8
,
6
,
4
,
2
,
0
,
-
2
,
-
4
,
-
6
,
-
8
,
-
10
,
-
12
,
-
14
,
-
16
,
-
18
,
-
20
,
-
22
,
-
24
,
-
26
,
-
28
,
-
30
};
migraphx
::
parameter_map
pp
;
pp
[
"A"
]
=
migraphx
::
argument
(
a
,
data_a
.
data
());
pp
[
"B"
]
=
migraphx
::
argument
(
b
,
data_b
.
data
());
auto
result
=
p
.
eval
(
pp
).
back
();
std
::
vector
<
int8_t
>
result_vector
;
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int8_t
>
gold
=
{
-
128
,
-
128
,
-
128
,
-
128
,
-
128
,
-
128
,
-
128
,
-
128
,
-
128
,
-
126
,
-
118
,
-
109
,
-
101
,
-
93
,
-
86
,
-
78
,
-
70
,
-
63
,
-
56
,
-
49
,
-
42
,
-
35
,
-
28
,
-
21
,
-
15
,
-
9
,
-
2
,
4
,
10
,
15
,
21
,
27
,
32
,
37
,
42
,
47
,
52
,
57
,
62
,
66
,
70
,
75
,
79
,
83
,
86
,
90
,
94
,
97
,
100
,
103
,
106
,
109
,
112
,
115
,
117
,
119
,
122
,
124
,
126
,
127
,
127
,
127
,
127
,
127
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
resize_downsample_f_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"resize_downsample_f_test.onnx"
);
...
...
test/py/onnx_backend_test.py
View file @
fed23ec7
...
...
@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
# fails
# from OnnxBackendNodeModelTest
backend_test
.
exclude
(
r
'test_gru_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_lstm_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_simple_rnn_batchwise_cpu'
)
# from OnnxBackendPyTorchConvertedModelTest
backend_test
.
exclude
(
r
'test_MaxPool1d_stride_padding_dilation_cpu'
)
...
...
test/quantization.cpp
View file @
fed23ec7
...
...
@@ -638,11 +638,10 @@ TEST_CASE(dot_float)
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
zp
);
auto
quant_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
quant_a
,
quant_b
);
std
::
vector
<
float
>
vec
(
sc
.
elements
(),
100.0
f
);
auto
dc
=
mm
->
add_literal
(
100.0
f
);
auto
mdc
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sc
.
lens
()}}),
dc
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
mdc
);
auto
scale_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
scale
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
...
...
@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args)
auto
pa
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
pb
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
scale_a
=
mm
->
add_literal
(
10.0
);
auto
scale_a
_lit
=
mm
->
add_literal
(
10.0
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
scale_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale_a
);
auto
scale_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale_a
_lit
);
auto
zp_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
zp
);
auto
qa
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pa
,
scale_a
,
zp_a
);
auto
scale_b
=
mm
->
add_literal
(
5.0
);
scale_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
scale_b
);
auto
scale_b
_lit
=
mm
->
add_literal
(
5.0
);
auto
scale_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
scale_b
_lit
);
auto
zp_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
zp
);
auto
qb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
scale
=
mm
->
add_literal
(
50.0
);
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
scale
);
auto
scale_a_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale_a_lit
);
auto
scale_b_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale_b_lit
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_a_mb
,
scale_b_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
};
...
...
@@ -799,18 +802,15 @@ TEST_CASE(dot_half_1arg)
auto
x
=
mm
->
add_parameter
(
"x"
,
sa
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
10.0
}));
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale
);
auto
scale
_lit
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
10.0
}));
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale
_lit
);
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
zp
);
auto
qx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
scale
,
zp
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qx
,
qx
);
auto
dq_scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
100.0
}));
dq_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
dq_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
dq_scale
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale
,
scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
};
...
...
@@ -852,9 +852,9 @@ TEST_CASE(conv_float)
auto
pw
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
10.0
f
);
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
);
auto
scale
_lit
=
mm
->
add_literal
(
10.0
f
);
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
_lit
);
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
zp
);
auto
quant_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
px
,
scale
,
zp
);
...
...
@@ -862,13 +862,11 @@ TEST_CASE(conv_float)
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
quant_x
,
quant_w
);
migraphx
::
shape
sc
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
std
::
vector
<
float
>
vec
(
sc
.
elements
(),
100.0
f
);
migraphx
::
shape
s_scale
{
migraphx
::
shape
::
float_type
,
sc
.
lens
()};
auto
d_scale
=
mm
->
add_literal
(
100.0
f
);
d_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
4
,
1
,
1
}}}),
d_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
d_scale
);
auto
scale_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
scale_lit
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
...
...
@@ -931,19 +929,20 @@ TEST_CASE(conv_half)
auto
pw
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
10.0
}));
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
);
auto
scale
_lit
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
10.0
}));
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
_lit
);
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
zp
);
auto
quant_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
px
,
scale
,
zp
);
auto
quant_w
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pw
,
scale
,
zp
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
quant_x
,
quant_w
);
auto
d_scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
100.0
}));
d_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
4
,
1
,
1
}}}),
d_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
d_scale
);
auto
scale_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
scale_lit
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
...
...
@@ -1187,9 +1186,9 @@ TEST_CASE(int8_subgraph)
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sy
.
lens
()}}),
zp1
);
auto
qb
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
b
,
sb
,
zpb
);
auto
qdot
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
s
o
=
then_mod
->
add_
literal
(
100.0
f
);
so
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sout
.
lens
()}}),
so
);
auto
s
1_mb
=
then_mod
->
add_
instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
s1
);
auto
so
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
s1_mb
,
s1_mb
);
auto
r
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
so
);
then_mod
->
add_return
({
r
});
...
...
@@ -1199,23 +1198,24 @@ TEST_CASE(int8_subgraph)
auto
w
=
mm
->
add_parameter
(
"w"
,
sw
);
// else submod
auto
*
else_mod
=
p
.
create_module
(
"If_6_else"
);
auto
sax
=
else_mod
->
add_literal
(
2.0
f
);
auto
sax
_lit
=
else_mod
->
add_literal
(
2.0
f
);
auto
zp
=
else_mod
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
sax
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
sax
);
auto
sax
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
sax
_lit
);
auto
zpx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
zp
);
auto
qx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
sax
,
zpx
);
auto
ssw
=
else_mod
->
add_literal
(
1.66667
f
);
ssw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
ssw
);
auto
ssw
_lit
=
else_mod
->
add_literal
(
1.66667
f
);
auto
ssw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
ssw
_lit
);
auto
zpw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
zp
);
auto
qw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
w
,
ssw
,
zpw
);
auto
qconv
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
qx
,
qw
);
auto
so1
=
else_mod
->
add_literal
(
3.33333
f
);
so1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sout
.
lens
()}}),
so1
);
auto
ssw_mb
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qconv
->
get_shape
().
lens
()}}),
ssw_lit
);
auto
so1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sax
,
ssw_mb
);
auto
r1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qconv
,
so1
);
else_mod
->
add_return
({
r1
});
...
...
test/ref/rnn_ops.cpp
View file @
fed23ec7
This diff is collapsed.
Click to expand it.
test/simplify_algebra_test.cpp
View file @
fed23ec7
...
...
@@ -1897,12 +1897,17 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto
concatb
=
m2
.
add_instruction
(
b
,
concat
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
input
,
concatb
);
auto
relu
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"relu"
),
sum
);
auto
rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
8
}}}),
relu
);
auto
slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
4
}}}),
rsp
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
relu
);
auto
rsp1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
4
}}}),
slc1
);
auto
slc2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
4
}},
{
"ends"
,
{
8
}}}),
rsp
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
slc1
,
slc2
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
relu
);
auto
rsp2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
4
}}}),
slc2
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
rsp1
,
rsp2
);
m2
.
add_instruction
(
pass_op
{},
add
);
}
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
...
...
@@ -2323,9 +2328,7 @@ TEST_CASE(simplify_dot_horiz_reshape)
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
4
}}}),
dot
);
auto
y
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
4
}},
{
"ends"
,
{
8
}}}),
dot
);
auto
x_cont
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
x_rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
4
,
2
,
2
}}}),
x_cont
);
auto
x_rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
4
,
2
,
2
}}}),
x
);
auto
y_rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
}}}),
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
{
x_rsp
,
y_rsp
});
...
...
@@ -2690,10 +2693,6 @@ void reorder_reshape_slice()
}
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
rsp_input
=
input
;
if
(
TransposeInput
)
{
rsp_input
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
{
input
});
}
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
128
,
30
,
64
};
auto
r
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
rsp_input
);
...
...
@@ -2976,9 +2975,8 @@ TEST_CASE(reorder_reshape_slice_multi_rsp)
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
3
,
1
,
4
}}}),
input
);
auto
c_t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
t1
);
auto
rsp1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
384
,
128
,
80
}}}),
c_
t1
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
384
,
128
,
80
}}}),
t1
);
auto
slc0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
256
}},
{
"ends"
,
{
384
}}}),
rsp1
);
auto
slc1
=
m2
.
add_instruction
(
...
...
@@ -2993,9 +2991,8 @@ TEST_CASE(reorder_reshape_slice_multi_rsp)
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
slc2
,
c_t_slc1
);
auto
c_t1_1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
t1
);
auto
rsp2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
12
,
32
,
128
,
80
}}}),
c_t1_
1
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
12
,
32
,
128
,
80
}}}),
t
1
);
auto
slc2_1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
4
}},
{
"ends"
,
{
8
}}}),
rsp2
);
...
...
@@ -3372,9 +3369,8 @@ TEST_CASE(dot_fusion_reshape)
auto
s1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
320
}},
{
"ends"
,
{
640
}}}),
d
);
auto
cont0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
s0
);
auto
r0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
4096
,
8
,
40
}}}),
cont
0
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
4096
,
8
,
40
}}}),
s
0
);
m2
.
add_return
({
r0
,
s1
});
};
...
...
test/simplify_qdq_test.cpp
View file @
fed23ec7
This diff is collapsed.
Click to expand it.
test/simplify_reshapes_test.cpp
View file @
fed23ec7
...
...
@@ -888,9 +888,8 @@ TEST_CASE(optimize_resize)
std
::
vector
<
int64_t
>
mb_dims
=
{
1
,
2
,
2
,
2
,
2
,
3
};
auto
mbx
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
mb_dims
}}),
rspx
);
auto
std_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
mbx
);
std
::
vector
<
int64_t
>
orig_dims
=
{
1
,
2
,
4
,
6
};
auto
rmb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
orig_dims
}}),
std_
mb
);
auto
rmb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
orig_dims
}}),
mb
x
);
auto
r
=
m
.
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
1
}}),
rmb
);
m
.
add_return
({
r
});
...
...
@@ -1301,9 +1300,8 @@ TEST_CASE(transpose_contiguous_reshape_unary)
auto
transpose_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
3
,
4
,
1
,
5
,
2
}}}),
reshape_ins1
);
auto
relu
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"relu"
),
transpose_ins
);
auto
cont_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
relu
);
auto
reshape_ins2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
2
,
10
,
10
}}}),
cont_ins
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
2
,
10
,
10
}}}),
relu
);
m2
.
add_instruction
(
pass_op
{},
reshape_ins2
);
}
EXPECT
(
m1
==
m2
);
...
...
@@ -1328,8 +1326,7 @@ TEST_CASE(transpose_contiguous_squeeze_unary)
auto
transpose_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
x
);
auto
rsqrt
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"rsqrt"
),
transpose_ins
);
auto
cont_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
rsqrt
);
auto
sq_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
1
}}}),
cont_ins
);
auto
sq_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
1
}}}),
rsqrt
);
m2
.
add_instruction
(
pass_op
{},
sq_ins
);
}
EXPECT
(
m1
==
m2
);
...
...
@@ -1355,9 +1352,7 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary)
auto
transpose_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
x
);
auto
round
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"nearbyint"
),
transpose_ins
);
auto
cont_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
round
);
auto
unsq_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
cont_ins
);
auto
unsq_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
round
);
m2
.
add_instruction
(
pass_op
{},
unsq_ins
);
}
EXPECT
(
m1
==
m2
);
...
...
test/verify/test_lstm_bidirct_3args_layout.cpp
0 → 100644
View file @
fed23ec7
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct
test_lstm_bidirct_3args_layout
:
verify_program
<
test_lstm_bidirct_3args_layout
>
{
migraphx
::
program
create_program
()
const
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
8
;
std
::
size_t
num_dirct
=
2
;
float
clip
=
0.0
f
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
}}),
seq
,
w
,
r
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
return
p
;
}
std
::
string
section
()
const
{
return
"rnn"
;
}
};
test/verify/test_lstm_bidirct_last_layout.cpp
0 → 100644
View file @
fed23ec7
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
struct
test_lstm_bidirct_last_layout
:
verify_program
<
test_lstm_bidirct_last_layout
>
{
migraphx
::
program
create_program
()
const
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
8
;
std
::
size_t
num_dirct
=
2
;
float
clip
=
0.0
f
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
b_shape
);
auto
ih
=
mm
->
add_parameter
(
"ih"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"ic"
,
ic_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
output
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
return
p
;
}
std
::
string
section
()
const
{
return
"rnn"
;
}
};
test/verify/test_lstm_forward_hs_layout.cpp
0 → 100644
View file @
fed23ec7
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
struct
test_lstm_forward_hs_layout
:
verify_program
<
test_lstm_forward_hs_layout
>
{
migraphx
::
program
create_program
()
const
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
8
;
std
::
size_t
num_dirct
=
1
;
float
clip
=
0.0
f
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
b_shape
);
auto
ih
=
mm
->
add_parameter
(
"ih"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"ic"
,
ic_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
return
p
;
}
std
::
string
section
()
const
{
return
"rnn"
;
}
};
test/verify/test_lstm_forward_last_layout.cpp
0 → 100644
View file @
fed23ec7
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct
test_lstm_forward_last_layout
:
verify_program
<
test_lstm_forward_last_layout
>
{
migraphx
::
program
create_program
()
const
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
8
;
std
::
size_t
num_dirct
=
1
;
float
clip
=
0.0
f
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
l_shape
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
b_shape
);
auto
ih
=
mm
->
add_parameter
(
"ih"
,
ih_shape
);
auto
len
=
mm
->
add_literal
(
migraphx
::
literal
(
l_shape
,
{
1
,
2
}));
auto
ic
=
mm
->
add_parameter
(
"ic"
,
ic_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
}}),
seq
,
w
,
r
,
bias
,
len
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
output
,
len
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
return
p
;
}
std
::
string
section
()
const
{
return
"rnn"
;
}
};
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
0 → 100644
View file @
fed23ec7
This diff is collapsed.
Click to expand it.
test/verify/test_lstm_reverse_3args_layout.cpp
0 → 100644
View file @
fed23ec7
This diff is collapsed.
Click to expand it.
test/verify/test_lstm_three_outputs_layout.cpp
0 → 100644
View file @
fed23ec7
This diff is collapsed.
Click to expand it.
tools/api/migraphx.h
View file @
fed23ec7
...
...
@@ -44,7 +44,8 @@
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#ifdef __cplusplus
...
...
@@ -70,7 +71,9 @@ typedef enum
}
migraphx_shape_datatype_t
;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
<%
generate_c_header
()
%>
<%
generate_c_header
()
%>
#ifdef __cplusplus
}
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment