Unverified commit 51acbdef authored by Paul Fultz II, committed by GitHub

Fixes in order to refactor everything to make_op (#675)



* Use generic op for eliminate_pad

* Formatting

* Improve error when loading a missing operator

* Add more enum tests

* Add more tests for constructing an op

* Formatting

* Fix failed tests

* Avoid duplicate branches

* Format file

* Default initialize variable
Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 405f30a0
@@ -21,18 +21,14 @@ void eliminate_pad::apply(module& p) const
         auto input = ins->inputs().front();
         if(input->name() != "pad")
             continue;
-        if(op_name == "convolution")
-            update_op(op::convolution{}, input, ins, p);
-        else if(op_name == "im2col")
-            update_op(op::im2col{}, input, ins, p);
+        if(op_name == "convolution" or op_name == "im2col")
+            update_op(input, ins, p);
         else if(op_name == "pooling")
             update_pooling(input, ins, p);
     }
 }

-template <class T>
-void eliminate_pad::update_op(T,
-                              const instruction_ref& input,
+void eliminate_pad::update_op(const instruction_ref& input,
                               const instruction_ref& ins,
                               module& p) const
 {
@@ -45,8 +41,8 @@ void eliminate_pad::update_op(T,
     std::vector<size_t> new_pads(kdims_it, kdims_it + kdims);

-    T op       = any_cast<T>(ins->get_operator());
-    op.padding = new_pads;
+    auto op = ins->get_operator();
+    op.from_value({"padding", new_pads});

     std::vector<instruction_ref> new_inputs{ins->inputs()};
     new_inputs.front() = input->inputs().front();
......
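A note on why a single non-template `update_op` now covers both `convolution` and `im2col`: `ins->get_operator()` returns the type-erased `migraphx::operation`, and `from_value` writes the supplied pairs into whichever reflected fields match by name. A sketch of the idea (`ins` and `new_pads` are the locals from the surrounding function):

```cpp
// Only the "padding" field is overwritten; other reflected fields
// (stride, dilation, ...) keep their current values, since from_value
// now skips keys that are absent from the value (see the
// v.contains(name) guard added later in this diff).
auto op = ins->get_operator();
op.from_value({"padding", new_pads});
```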
@@ -20,9 +20,9 @@ using module = program;
 struct eliminate_pad
 {
     std::string name() const { return "eliminate_pad"; }
     void apply(module& p) const;
-    template <class T>
-    void update_op(T, const instruction_ref& input, const instruction_ref& ins, module& p) const;
+    void update_op(const instruction_ref& input, const instruction_ref& ins, module& p) const;
     void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& p) const;
 };
......
@@ -24,7 +24,7 @@ struct rnn_var_sl_shift_output
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(f(self.output_name, "hidden_states"), f(self.direction, "direction"));
+        return pack(f(self.output_name, "output_name"), f(self.direction, "direction"));
     }

     std::string name() const { return "rnn_var_sl_shift_output"; }
......
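The string passed alongside each field in `reflect` is the key used by `to_value`/`from_value`, and therefore by `make_op`, so it has to match what callers write; with the old "hidden_states" key, `output_name` could not be set through `make_op`. A sketch of the call this fix enables (the same construction appears as a new test later in this diff):

```cpp
// "output_name" now maps onto the reflected member self.output_name.
auto shift = migraphx::make_op(
    "rnn_var_sl_shift_output",
    {{"output_name", "hidden_states"}, {"direction", migraphx::op::rnn_direction::reverse}});
```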
@@ -24,7 +24,7 @@ void from_value(const value& v, T& x);
 template <class T>
 T from_value(const value& v)
 {
-    T x;
+    T x{};
     from_value(v, x);
     return x;
 }
@@ -143,6 +143,7 @@ void from_value_impl(rank<3>, const value& v, T& x)
 {
     reflect_each(x, [&](auto& y, const std::string& name) {
         using type = std::decay_t<decltype(y)>;
+        if(v.contains(name))
             y = from_value<type>(v.at(name).without_key());
     });
 }
......
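Two effects of this hunk, sketched below: value-initializing with `T x{}` zeroes members that lack in-class initializers, and the `contains` guard lets a value that sets only some reflected fields deserialize cleanly, leaving the rest at their defaults instead of throwing from `v.at(name)`. The struct here is illustrative only (not part of the PR), and it assumes the `pack`/`reflect` helpers and header locations used elsewhere in MIGraphX:

```cpp
#include <cstddef>
#include <vector>
#include <migraphx/serialize.hpp> // assumed location of from_value/to_value

struct toy_op // hypothetical type for illustration
{
    std::vector<std::size_t> padding = {0, 0};
    int group                        = 1;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.padding, "padding"), f(self.group, "group"));
    }
};

void example()
{
    // Only "padding" is present in the value: group keeps its default of 1
    // rather than from_value throwing because the "group" key is missing.
    auto t = migraphx::from_value<toy_op>({{"padding", {1, 1}}});
    (void)t;
}
```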
@@ -7,6 +7,8 @@ inline namespace MIGRAPHX_INLINE_NS {
 operation make_op(const std::string& name) { return load_op(name); }
 operation make_op(const std::string& name, const value& v)
 {
+    if(not(v.is_object() or (v.empty() and v.is_array())))
+        MIGRAPHX_THROW("Value is not an object");
     auto op = load_op(name);
     // Merge values
     value w = op.to_value();
......
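The new guard makes the contract explicit: the two-argument `make_op` only accepts an object of {key, value} pairs (an empty braced list, which arrives as an empty array, is also tolerated). A sketch of accepted versus rejected input, written in the style of the tests later in this diff:

```cpp
// Accepted: an object of {key, value} pairs.
auto ok = migraphx::make_op("convolution", {{"padding", {1, 1}}});

// Rejected with "Value is not an object": a plain array of numbers,
// which cannot be merged into the operator's reflected fields.
EXPECT(test::throws([] { migraphx::make_op("convolution", {1, 1}); }));
```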
@@ -10,7 +10,13 @@ std::unordered_map<std::string, operation>& op_map()
     return m;
 }
 void register_op(const operation& op) { op_map()[op.name()] = op; }
-operation load_op(const std::string& name) { return op_map().at(name); }
+operation load_op(const std::string& name)
+{
+    auto it = op_map().find(name);
+    if(it == op_map().end())
+        MIGRAPHX_THROW("Operator not found: " + name);
+    return it->second;
+}
 std::vector<std::string> get_operators()
 {
......
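Using `find` instead of `at` turns a lookup miss into a descriptive MIGraphX error that names the operator, rather than the opaque `std::out_of_range` thrown by `unordered_map::at`. A sketch of the observable difference, in test style (the operator name below is deliberately made up):

```cpp
// Now throws a MIGraphX exception whose message is
// "Operator not found: not_a_real_op"; previously this surfaced as
// std::out_of_range with no indication of which name failed.
EXPECT(test::throws([] { migraphx::make_op("not_a_real_op"); }));
```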
@@ -98,7 +98,7 @@ void set_vector(std::shared_ptr<value_base_impl>& x,
 value::value(const std::initializer_list<value>& i) : x(nullptr)
 {
-    if(i.size() == 2 and i.begin()->is_string())
+    if(i.size() == 2 and i.begin()->is_string() and i.begin()->get_key().empty())
     {
         key    = i.begin()->get_string();
         auto r = (i.begin() + 1)->x;
......
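The added `get_key().empty()` check disambiguates the braced-init constructor: an inner pair such as `{"one", "onev"}` already carries the key "one", so an outer list of such pairs should build an object rather than be re-read as a single {key, value} pair keyed by the first string. The new `value_construct_object_string_value` tests below exercise this; a condensed sketch:

```cpp
// Previously the outer list (size 2, first element a string) collapsed into
// one pair with key "onev" and value "twov", dropping the "one"/"two" keys.
// With the key check it is parsed as an object with two keyed entries.
migraphx::value v = {{"one", "onev"}, {"two", "twov"}};
// v.is_object() == true
// v.at("one").get_string() == "onev"
// v.at("two").get_string() == "twov"
```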
@@ -2,6 +2,7 @@
 #include <migraphx/operation.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/op/convolution.hpp>
+#include <migraphx/op/rnn_variable_seq_lens.hpp>
 #include <sstream>
 #include <string>
 #include "test.hpp"
@@ -39,6 +40,15 @@ TEST_CASE(make_op_from_value2)
     EXPECT(x == y);
 }

+TEST_CASE(make_rnn_op_from_value)
+{
+    migraphx::op::rnn_direction dirct = migraphx::op::rnn_direction::reverse;
+    migraphx::operation x             = migraphx::make_op(
+        "rnn_var_sl_shift_output", {{"output_name", "hidden_states"}, {"direction", dirct}});
+    migraphx::operation y = migraphx::op::rnn_var_sl_shift_output{"hidden_states", dirct};
+    EXPECT(x == y);
+}
+
 TEST_CASE(make_op_invalid_key)
 {
     EXPECT(test::throws([] { migraphx::make_op("convolution", {{"paddings", {1, 1}}}); }));
......
@@ -72,7 +72,7 @@ TEST_CASE(value_construct_bool)
     EXPECT(v.get_key().empty());
 }

-TEST_CASE(value_construct_enum)
+TEST_CASE(value_construct_enum1)
 {
     migraphx::value v = enum_type::a;
     EXPECT(v.is_int64());
@@ -81,6 +81,24 @@ EXPECT(v.get_key().empty());
     EXPECT(v.get_key().empty());
 }

+TEST_CASE(value_construct_enum2)
+{
+    migraphx::value v = enum_type::b;
+    EXPECT(v.is_int64());
+    EXPECT(v.get_int64() == static_cast<std::uint64_t>(enum_type::b));
+    EXPECT(bool{v.to<enum_type>() == enum_type::b});
+    EXPECT(v.get_key().empty());
+}
+
+TEST_CASE(value_construct_enum3)
+{
+    migraphx::value v = enum_type::c;
+    EXPECT(v.is_int64());
+    EXPECT(v.get_int64() == static_cast<std::uint64_t>(enum_type::c));
+    EXPECT(bool{v.to<enum_type>() == enum_type::c});
+    EXPECT(v.get_key().empty());
+}
+
 TEST_CASE(value_construct_empty_object)
 {
     migraphx::value v = migraphx::value::object{};
@@ -467,6 +485,36 @@ TEST_CASE(value_emplace_object)
     EXPECT(v["three"].get_key() == "three");
 }

+TEST_CASE(value_construct_object_string_value)
+{
+    migraphx::value v = {{"one", "onev"}, {"two", "twov"}};
+    EXPECT(v.is_object());
+    EXPECT(v.size() == 2);
+    EXPECT(not v.empty());
+    EXPECT(v.data() != nullptr);
+    EXPECT(v.at("one").is_string());
+    EXPECT(v.at("one").get_key() == "one");
+    EXPECT(v.at("one").get_string() == "onev");
+    EXPECT(v.at("two").is_string());
+    EXPECT(v.at("two").get_key() == "two");
+    EXPECT(v.at("two").get_string() == "twov");
+}
+
+TEST_CASE(value_construct_object_string_mixed_value)
+{
+    migraphx::value v = {{"one", "onev"}, {"two", 2}};
+    EXPECT(v.is_object());
+    EXPECT(v.size() == 2);
+    EXPECT(not v.empty());
+    EXPECT(v.data() != nullptr);
+    EXPECT(v.at("one").is_string());
+    EXPECT(v.at("one").get_key() == "one");
+    EXPECT(v.at("one").get_string() == "onev");
+    EXPECT(v.at("two").is_int64());
+    EXPECT(v.at("two").get_key() == "two");
+    EXPECT(v.at("two").get_int64() == 2);
+}
+
 TEST_CASE(value_compare)
 {
     EXPECT(migraphx::value(1) == migraphx::value(1));
......