Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c2527321
Unverified
Commit
c2527321
authored
Apr 16, 2019
by
mvermeulen
Committed by
GitHub
Apr 16, 2019
Browse files
Merge pull request #240 from ROCmSoftwarePlatform/add_reflect_attribute
Add reflect attribute
parents
0861bb2a
98338dea
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
311 additions
and
174 deletions
+311
-174
src/include/migraphx/op/common.hpp
src/include/migraphx/op/common.hpp
+2
-0
src/include/migraphx/op/gather.hpp
src/include/migraphx/op/gather.hpp
+7
-0
src/include/migraphx/op/gru.hpp
src/include/migraphx/op/gru.hpp
+10
-0
src/include/migraphx/op/logsoftmax.hpp
src/include/migraphx/op/logsoftmax.hpp
+7
-0
src/include/migraphx/op/lstm.hpp
src/include/migraphx/op/lstm.hpp
+9
-0
src/include/migraphx/op/rnn.hpp
src/include/migraphx/op/rnn.hpp
+9
-0
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+10
-0
test/onnx/onnx_gru_reverse.onnx
test/onnx/onnx_gru_reverse.onnx
+0
-0
test/onnx/onnx_rnn_3args.onnx
test/onnx/onnx_rnn_3args.onnx
+0
-0
test/onnx/onnx_rnn_test.cpp
test/onnx/onnx_rnn_test.cpp
+257
-174
No files found.
src/include/migraphx/op/common.hpp
View file @
c2527321
...
@@ -31,6 +31,8 @@ enum class rnn_direction
...
@@ -31,6 +31,8 @@ enum class rnn_direction
bidirectional
,
bidirectional
,
};
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
rnn_direction
v
);
}
// namespace op
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/op/gather.hpp
View file @
c2527321
...
@@ -19,6 +19,13 @@ namespace op {
...
@@ -19,6 +19,13 @@ namespace op {
struct
gather
struct
gather
{
{
int
axis
=
0
;
int
axis
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
std
::
string
name
()
const
{
return
"gather"
;
}
std
::
string
name
()
const
{
return
"gather"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
...
src/include/migraphx/op/gru.hpp
View file @
c2527321
...
@@ -27,6 +27,16 @@ struct gru
...
@@ -27,6 +27,16 @@ struct gru
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
int
linear_before_reset
=
0
;
int
linear_before_reset
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
hidden_size
,
"hidden_size"
),
f
(
self
.
actv_funcs
,
"actv_func"
),
f
(
self
.
direction
,
"direction"
),
f
(
self
.
clip
,
"clip"
),
f
(
self
.
linear_before_reset
,
"linear_before_reset"
));
}
std
::
string
name
()
const
{
return
"gru"
;
}
std
::
string
name
()
const
{
return
"gru"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
...
src/include/migraphx/op/logsoftmax.hpp
View file @
c2527321
...
@@ -19,6 +19,13 @@ namespace op {
...
@@ -19,6 +19,13 @@ namespace op {
struct
logsoftmax
struct
logsoftmax
{
{
int
axis
=
1
;
int
axis
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
std
::
string
name
()
const
{
return
"logsoftmax"
;
}
std
::
string
name
()
const
{
return
"logsoftmax"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
...
src/include/migraphx/op/lstm.hpp
View file @
c2527321
...
@@ -25,6 +25,15 @@ struct lstm
...
@@ -25,6 +25,15 @@ struct lstm
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
int
input_forget
=
0
;
int
input_forget
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
hidden_size
,
"hidden_size"
),
f
(
self
.
actv_funcs
,
"actv_func"
),
f
(
self
.
direction
,
"direction"
),
f
(
self
.
input_forget
,
"input_forget"
));
}
std
::
string
name
()
const
{
return
"lstm"
;
}
std
::
string
name
()
const
{
return
"lstm"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
...
src/include/migraphx/op/rnn.hpp
View file @
c2527321
...
@@ -25,6 +25,15 @@ struct rnn
...
@@ -25,6 +25,15 @@ struct rnn
rnn_direction
direction
=
rnn_direction
::
forward
;
rnn_direction
direction
=
rnn_direction
::
forward
;
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
hidden_size
,
"hidden_size"
),
f
(
self
.
actv_funcs
,
"actv_func"
),
f
(
self
.
direction
,
"direction"
),
f
(
self
.
clip
,
"clip"
));
}
std
::
string
name
()
const
{
return
"rnn"
;
}
std
::
string
name
()
const
{
return
"rnn"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
...
src/rewrite_rnn.cpp
View file @
c2527321
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -1166,5 +1167,14 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
...
@@ -1166,5 +1167,14 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
}
}
}
}
namespace
op
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
rnn_direction
v
)
{
std
::
vector
<
std
::
string
>
rnn_direction_str
=
{
"forward"
,
"reverse"
,
"bidirectional"
};
os
<<
rnn_direction_str
[
static_cast
<
std
::
underlying_type
<
rnn_direction
>::
type
>
(
v
)];
return
os
;
}
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
test/onnx/onnx_gru_reverse.onnx
View file @
c2527321
No preview for this file type
test/onnx/onnx_rnn_3args.onnx
View file @
c2527321
No preview for this file type
test/onnx/onnx_rnn_test.cpp
View file @
c2527321
...
@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction)
...
@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction)
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn_direction
::
reverse
,
migraphx
::
op
::
rnn_direction
::
forward
,
clip
},
clip
},
seq
,
seq
,
w
,
w
,
...
@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args)
...
@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args)
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
{
migraphx
::
op
::
relu
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
reverse
,
migraphx
::
op
::
rnn_direction
::
reverse
,
clip
},
clip
},
seq
,
seq
,
...
@@ -373,7 +373,10 @@ TEST_CASE(gru_test_args)
...
@@ -373,7 +373,10 @@ TEST_CASE(gru_test_args)
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
relu
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
},
clip
},
seq
,
seq
,
...
@@ -414,8 +417,14 @@ TEST_CASE(gru_test_actv_funcs)
...
@@ -414,8 +417,14 @@ TEST_CASE(gru_test_actv_funcs)
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
gru
{
hs
,
{},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
},
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
@@ -445,9 +454,14 @@ TEST_CASE(gru_test_actv_funcs)
...
@@ -445,9 +454,14 @@ TEST_CASE(gru_test_actv_funcs)
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
gru
{
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
hs
,
{
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
@@ -479,7 +493,10 @@ TEST_CASE(gru_test_actv_funcs)
...
@@ -479,7 +493,10 @@ TEST_CASE(gru_test_actv_funcs)
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
},
clip
},
seq
,
seq
,
...
@@ -511,9 +528,12 @@ TEST_CASE(gru_test_actv_funcs)
...
@@ -511,9 +528,12 @@ TEST_CASE(gru_test_actv_funcs)
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
gru
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{}},
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
},
clip
},
seq
,
seq
,
...
@@ -546,7 +566,10 @@ TEST_CASE(gru_test_actv_funcs)
...
@@ -546,7 +566,10 @@ TEST_CASE(gru_test_actv_funcs)
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{},
migraphx
::
op
::
rnn_direction
::
forward
,
clip
},
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
forward
,
clip
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
@@ -576,9 +599,11 @@ TEST_CASE(gru_test_actv_funcs)
...
@@ -576,9 +599,11 @@ TEST_CASE(gru_test_actv_funcs)
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
gru
{
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hs
,
hs
,
{
migraphx
::
op
::
relu
{}},
migraphx
::
op
::
rnn_direction
::
reverse
,
clip
},
{
migraphx
::
op
::
relu
{},
migraphx
::
op
::
relu
{}},
migraphx
::
op
::
rnn_direction
::
reverse
,
clip
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
@@ -826,7 +851,12 @@ TEST_CASE(lstm_forward_actv_func)
...
@@ -826,7 +851,12 @@ TEST_CASE(lstm_forward_actv_func)
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
{},
migraphx
::
op
::
rnn_direction
::
forward
,
clip
,
input_forget
},
migraphx
::
op
::
lstm
{
hs
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
forward
,
clip
,
input_forget
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
@@ -851,8 +881,10 @@ TEST_CASE(lstm_forward_actv_func)
...
@@ -851,8 +881,10 @@ TEST_CASE(lstm_forward_actv_func)
auto
bias
=
p
.
add_parameter
(
"bias"
,
bias_shape
);
auto
bias
=
p
.
add_parameter
(
"bias"
,
bias_shape
);
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
auto
out_hs
=
p
.
add_instruction
(
{
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
lstm
{
hs
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn_direction
::
forward
,
migraphx
::
op
::
rnn_direction
::
forward
,
clip
,
clip
,
input_forget
},
input_forget
},
...
@@ -881,9 +913,10 @@ TEST_CASE(lstm_forward_actv_func)
...
@@ -881,9 +913,10 @@ TEST_CASE(lstm_forward_actv_func)
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
sl_shape
);
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
sl_shape
);
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
migraphx
::
op
::
lstm
{
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn_direction
::
forward
,
migraphx
::
op
::
rnn_direction
::
forward
,
clip
,
clip
,
input_forget
},
input_forget
},
...
@@ -993,7 +1026,12 @@ TEST_CASE(lstm_reverse)
...
@@ -993,7 +1026,12 @@ TEST_CASE(lstm_reverse)
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
{},
migraphx
::
op
::
rnn_direction
::
forward
,
clip
,
input_forget
},
migraphx
::
op
::
lstm
{
hs
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
reverse
,
clip
,
input_forget
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
@@ -1037,10 +1075,14 @@ TEST_CASE(lstm_bidirectional)
...
@@ -1037,10 +1075,14 @@ TEST_CASE(lstm_bidirectional)
auto
ic
=
p
.
add_parameter
(
"c0"
,
ih_shape
);
auto
ic
=
p
.
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
p
.
add_parameter
(
"pph"
,
pph_shape
);
auto
pph
=
p
.
add_parameter
(
"pph"
,
pph_shape
);
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
lstm
{
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
hs
,
{
migraphx
::
op
::
sigmoid
{},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
clip
,
input_forget
},
input_forget
},
...
@@ -1067,10 +1109,14 @@ TEST_CASE(lstm_bidirectional)
...
@@ -1067,10 +1109,14 @@ TEST_CASE(lstm_bidirectional)
auto
r
=
p
.
add_parameter
(
"r"
,
r_shape
);
auto
r
=
p
.
add_parameter
(
"r"
,
r_shape
);
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
lstm
{
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
hs
,
{
migraphx
::
op
::
sigmoid
{},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
clip
,
input_forget
},
input_forget
},
...
@@ -1098,10 +1144,14 @@ TEST_CASE(lstm_bidirectional)
...
@@ -1098,10 +1144,14 @@ TEST_CASE(lstm_bidirectional)
auto
bias
=
p
.
add_parameter
(
"bias"
,
bias_shape
);
auto
bias
=
p
.
add_parameter
(
"bias"
,
bias_shape
);
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
lstm
{
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
hs
,
{
migraphx
::
op
::
sigmoid
{},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
clip
,
input_forget
},
input_forget
},
...
@@ -1130,10 +1180,14 @@ TEST_CASE(lstm_bidirectional)
...
@@ -1130,10 +1180,14 @@ TEST_CASE(lstm_bidirectional)
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
sl_shape
);
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
sl_shape
);
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
lstm
{
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
hs
,
{
migraphx
::
op
::
sigmoid
{},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
clip
,
input_forget
},
input_forget
},
...
@@ -1163,10 +1217,14 @@ TEST_CASE(lstm_bidirectional)
...
@@ -1163,10 +1217,14 @@ TEST_CASE(lstm_bidirectional)
auto
ih
=
p
.
add_parameter
(
"h0"
,
ih_shape
);
auto
ih
=
p
.
add_parameter
(
"h0"
,
ih_shape
);
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
lstm
{
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
hs
,
{
migraphx
::
op
::
sigmoid
{},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
clip
,
input_forget
},
input_forget
},
...
@@ -1197,10 +1255,14 @@ TEST_CASE(lstm_bidirectional)
...
@@ -1197,10 +1255,14 @@ TEST_CASE(lstm_bidirectional)
auto
ic
=
p
.
add_parameter
(
"c0"
,
ih_shape
);
auto
ic
=
p
.
add_parameter
(
"c0"
,
ih_shape
);
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
lstm
{
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
hs
,
{
migraphx
::
op
::
sigmoid
{},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
clip
,
input_forget
},
input_forget
},
...
@@ -1244,9 +1306,17 @@ TEST_CASE(lstm_bi_actv_funcs)
...
@@ -1244,9 +1306,17 @@ TEST_CASE(lstm_bi_actv_funcs)
auto
r
=
p
.
add_parameter
(
"r"
,
r_shape
);
auto
r
=
p
.
add_parameter
(
"r"
,
r_shape
);
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
out_hs
=
p
.
add_instruction
(
auto
out_hs
=
migraphx
::
op
::
lstm
{
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
hs
,
{},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
input_forget
},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
input_forget
},
seq
,
seq
,
w
,
w
,
r
,
r
,
...
@@ -1273,7 +1343,12 @@ TEST_CASE(lstm_bi_actv_funcs)
...
@@ -1273,7 +1343,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
{
migraphx
::
op
::
sigmoid
{}},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
clip
,
input_forget
},
input_forget
},
...
@@ -1304,7 +1379,12 @@ TEST_CASE(lstm_bi_actv_funcs)
...
@@ -1304,7 +1379,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{}},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
clip
,
input_forget
},
input_forget
},
...
@@ -1337,6 +1417,8 @@ TEST_CASE(lstm_bi_actv_funcs)
...
@@ -1337,6 +1417,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto
out_hs
=
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
p
.
add_instruction
(
migraphx
::
op
::
lstm
{
hs
,
{
migraphx
::
op
::
sigmoid
{},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
tanh
{}},
...
@@ -1376,6 +1458,7 @@ TEST_CASE(lstm_bi_actv_funcs)
...
@@ -1376,6 +1458,7 @@ TEST_CASE(lstm_bi_actv_funcs)
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn_direction
::
bidirectional
,
migraphx
::
op
::
rnn_direction
::
bidirectional
,
clip
,
clip
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment