Commit d63dbddc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into pass_manager

parents 6a141efa c2527321
...@@ -31,6 +31,8 @@ enum class rnn_direction ...@@ -31,6 +31,8 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -19,6 +19,13 @@ namespace op { ...@@ -19,6 +19,13 @@ namespace op {
struct gather struct gather
{ {
int axis = 0; int axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "gather"; } std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -27,6 +27,16 @@ struct gru ...@@ -27,6 +27,16 @@ struct gru
float clip = 0.0f; float clip = 0.0f;
int linear_before_reset = 0; int linear_before_reset = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.linear_before_reset, "linear_before_reset"));
}
std::string name() const { return "gru"; } std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -19,6 +19,13 @@ namespace op { ...@@ -19,6 +19,13 @@ namespace op {
struct logsoftmax struct logsoftmax
{ {
int axis = 1; int axis = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "logsoftmax"; } std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -25,6 +25,15 @@ struct lstm ...@@ -25,6 +25,15 @@ struct lstm
float clip = 0.0f; float clip = 0.0f;
int input_forget = 0; int input_forget = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.input_forget, "input_forget"));
}
std::string name() const { return "lstm"; } std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -25,6 +25,15 @@ struct rnn ...@@ -25,6 +25,15 @@ struct rnn
rnn_direction direction = rnn_direction::forward; rnn_direction direction = rnn_direction::forward;
float clip = 0.0f; float clip = 0.0f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"));
}
std::string name() const { return "rnn"; } std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -1166,5 +1167,14 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1166,5 +1167,14 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
} }
} }
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -363,6 +363,16 @@ struct tf_parser ...@@ -363,6 +363,16 @@ struct tf_parser
int64_t axis = 0; int64_t axis = 0;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
axis = attributes.at("axis").i(); axis = attributes.at("axis").i();
size_t input_size = args.front()->get_shape().lens().size();
if(axis > input_size)
{
MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
" must be smaller than input size " + to_string(input_size));
}
// check if input arg needs axis to be converted to NCHW
if(input_size >= 4)
axis = parse_axis(axis);
std::transform( std::transform(
args.begin(), args.begin(),
args.end(), args.end(),
......
This diff is collapsed.
:
0 Placeholder*
dtype0*
shape:
:
1 Placeholder*
dtype0*
shape:
:
2 Placeholder*
dtype0*
shape:
4
pack1Pack012*
T0*
axis*
N"
\ No newline at end of file
...@@ -151,6 +151,28 @@ TEST_CASE(pack_test) ...@@ -151,6 +151,28 @@ TEST_CASE(pack_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(pack_test_nhwc)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
std::vector<migraphx::instruction_ref> args{l0, l1, l2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 1;
std::transform(args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog);
}
TEST_CASE(pooling_test) TEST_CASE(pooling_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment