Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
80016cff
Commit
80016cff
authored
Feb 05, 2019
by
Shucai Xiao
Browse files
clang format
parent
846afb76
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
122 additions
and
115 deletions
+122
-115
src/rewrite_gru.cpp
src/rewrite_gru.cpp
+26
-31
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+96
-84
No files found.
src/rewrite_gru.cpp
View file @
80016cff
...
@@ -62,26 +62,20 @@ void rewrite_gru::apply(program& prog) const
...
@@ -62,26 +62,20 @@ void rewrite_gru::apply(program& prog) const
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
}
auto
ret_forward
=
gru_cell
(
true
,
auto
ret_forward
=
gru_cell
(
true
,
prog
,
prog
,
ins
,
ins
,
{
args
[
0
],
{
args
[
0
],
w_forward
,
r_forward
,
bias_forward
,
ih_forward
},
w_forward
,
r_forward
,
bias_forward
,
ih_forward
},
gru_op
.
linear_before_reset
,
gru_op
.
linear_before_reset
,
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
));
actv_funcs
.
at
(
1
));
auto
ret_reverse
=
gru_cell
(
false
,
auto
ret_reverse
=
gru_cell
(
false
,
prog
,
prog
,
ins
,
ins
,
{
args
[
0
],
{
args
[
0
],
w_reverse
,
r_reverse
,
bias_reverse
,
ih_reverse
},
w_reverse
,
r_reverse
,
bias_reverse
,
ih_reverse
},
gru_op
.
linear_before_reset
,
gru_op
.
linear_before_reset
,
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
3
));
actv_funcs
.
at
(
3
));
...
@@ -159,10 +153,11 @@ void rewrite_gru::apply(program& prog) const
...
@@ -159,10 +153,11 @@ void rewrite_gru::apply(program& prog) const
// replace the corresponding gru_last_output instruction
// replace the corresponding gru_last_output instruction
// with the last_output, if gru_last_output exists
// with the last_output, if gru_last_output exists
auto
last_output_it
=
std
::
find_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
[](
auto
i
)
{
auto
last_output_it
=
std
::
find_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
[](
auto
i
)
{
return
i
->
name
()
==
"gru_last_output"
;
return
i
->
name
()
==
"gru_last_output"
;
});
});
if
(
last_output_it
!=
ins
->
outputs
().
end
())
if
(
last_output_it
!=
ins
->
outputs
().
end
())
{
{
prog
.
replace_instruction
(
*
last_output_it
,
last_output
);
prog
.
replace_instruction
(
*
last_output_it
,
last_output
);
}
}
...
...
test/onnx/onnx_test.cpp
View file @
80016cff
...
@@ -645,8 +645,10 @@ TEST_CASE(gru_test)
...
@@ -645,8 +645,10 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
seq_len
=
auto
seq_len
=
...
@@ -677,8 +679,10 @@ TEST_CASE(gru_test)
...
@@ -677,8 +679,10 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
seq_len
=
auto
seq_len
=
...
@@ -709,18 +713,21 @@ TEST_CASE(gru_test)
...
@@ -709,18 +713,21 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
seq_len
=
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{
migraphx
::
op
::
tanh
{},
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
relu
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
relu
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
gru
::
bidirectional
,
migraphx
::
op
::
gru
::
bidirectional
,
clip
},
clip
},
seq
,
seq
,
...
@@ -742,8 +749,10 @@ TEST_CASE(gru_test)
...
@@ -742,8 +749,10 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
...
@@ -769,8 +778,10 @@ TEST_CASE(gru_test)
...
@@ -769,8 +778,10 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
...
@@ -799,8 +810,10 @@ TEST_CASE(gru_test)
...
@@ -799,8 +810,10 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
seq_len
=
auto
seq_len
=
...
@@ -831,8 +844,10 @@ TEST_CASE(gru_test)
...
@@ -831,8 +844,10 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
seq_len
=
auto
seq_len
=
...
@@ -840,10 +855,7 @@ TEST_CASE(gru_test)
...
@@ -840,10 +855,7 @@ TEST_CASE(gru_test)
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{},
migraphx
::
op
::
gru
::
bidirectional
,
clip
},
{},
migraphx
::
op
::
gru
::
bidirectional
,
clip
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
@@ -863,19 +875,18 @@ TEST_CASE(gru_test)
...
@@ -863,19 +875,18 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
seq_len
=
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
migraphx
::
op
::
gru
{
hs
,
{
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
gru
::
bidirectional
,
clip
},
{
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
gru
::
bidirectional
,
clip
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
@@ -895,8 +906,10 @@ TEST_CASE(gru_test)
...
@@ -895,8 +906,10 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
seq_len
=
auto
seq_len
=
...
@@ -927,16 +940,18 @@ TEST_CASE(gru_test)
...
@@ -927,16 +940,18 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
seq_len
=
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
migraphx
::
op
::
gru
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{}},
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
gru
::
bidirectional
,
migraphx
::
op
::
gru
::
bidirectional
,
clip
},
clip
},
...
@@ -959,19 +974,17 @@ TEST_CASE(gru_test)
...
@@ -959,19 +974,17 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
seq_len
=
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{},
migraphx
::
op
::
gru
::
forward
,
clip
},
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{},
migraphx
::
op
::
gru
::
forward
,
clip
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
@@ -991,19 +1004,18 @@ TEST_CASE(gru_test)
...
@@ -991,19 +1004,18 @@ TEST_CASE(gru_test)
auto
seq
=
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
w
=
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
,
hs
}});
auto
bias
=
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
6
*
hs
}});
auto
seq_len
=
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
migraphx
::
op
::
gru
{
hs
,
{
migraphx
::
op
::
relu
{}},
migraphx
::
op
::
gru
::
reverse
,
clip
},
{
migraphx
::
op
::
relu
{}},
migraphx
::
op
::
gru
::
reverse
,
clip
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment