"src/targets/vscode:/vscode.git/clone" did not exist on "c8aa00bfede8a070dffe6ed83a8ad07abd95a142"
Unverified Commit 51acbdef authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fixes in order to refactor the everything to make_op (#675)



* Use generic op for eliminate_pad

* Formatting

* Improve error when loading a missing operator

* Add more enum tests

* Add more tests for constructing an op

* Formatting

* Fix failed tests

* Avoid duplicate branches

* Format file

* Default initialize variable
Co-authored-by: default avatarShucai Xiao <shucai.xiao@amd.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 405f30a0
......@@ -21,18 +21,14 @@ void eliminate_pad::apply(module& p) const
auto input = ins->inputs().front();
if(input->name() != "pad")
continue;
if(op_name == "convolution")
update_op(op::convolution{}, input, ins, p);
else if(op_name == "im2col")
update_op(op::im2col{}, input, ins, p);
if(op_name == "convolution" or op_name == "im2col")
update_op(input, ins, p);
else if(op_name == "pooling")
update_pooling(input, ins, p);
}
}
template <class T>
void eliminate_pad::update_op(T,
const instruction_ref& input,
void eliminate_pad::update_op(const instruction_ref& input,
const instruction_ref& ins,
module& p) const
{
......@@ -45,8 +41,8 @@ void eliminate_pad::update_op(T,
std::vector<size_t> new_pads(kdims_it, kdims_it + kdims);
T op = any_cast<T>(ins->get_operator());
op.padding = new_pads;
auto op = ins->get_operator();
op.from_value({"padding", new_pads});
std::vector<instruction_ref> new_inputs{ins->inputs()};
new_inputs.front() = input->inputs().front();
......
......@@ -20,9 +20,9 @@ using module = program;
struct eliminate_pad
{
std::string name() const { return "eliminate_pad"; }
void apply(module& p) const;
template <class T>
void update_op(T, const instruction_ref& input, const instruction_ref& ins, module& p) const;
void update_op(const instruction_ref& input, const instruction_ref& ins, module& p) const;
void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& p) const;
};
......
......@@ -24,7 +24,7 @@ struct rnn_var_sl_shift_output
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_name, "hidden_states"), f(self.direction, "direction"));
return pack(f(self.output_name, "output_name"), f(self.direction, "direction"));
}
std::string name() const { return "rnn_var_sl_shift_output"; }
......
......@@ -24,7 +24,7 @@ void from_value(const value& v, T& x);
template <class T>
T from_value(const value& v)
{
T x;
T x{};
from_value(v, x);
return x;
}
......@@ -143,7 +143,8 @@ void from_value_impl(rank<3>, const value& v, T& x)
{
reflect_each(x, [&](auto& y, const std::string& name) {
using type = std::decay_t<decltype(y)>;
y = from_value<type>(v.at(name).without_key());
if(v.contains(name))
y = from_value<type>(v.at(name).without_key());
});
}
......
......@@ -7,6 +7,8 @@ inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); }
operation make_op(const std::string& name, const value& v)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object");
auto op = load_op(name);
// Merge values
value w = op.to_value();
......
......@@ -10,7 +10,13 @@ std::unordered_map<std::string, operation>& op_map()
return m;
}
void register_op(const operation& op) { op_map()[op.name()] = op; }
operation load_op(const std::string& name) { return op_map().at(name); }
operation load_op(const std::string& name)
{
auto it = op_map().find(name);
if(it == op_map().end())
MIGRAPHX_THROW("Operator not found: " + name);
return it->second;
}
std::vector<std::string> get_operators()
{
......
......@@ -98,7 +98,7 @@ void set_vector(std::shared_ptr<value_base_impl>& x,
value::value(const std::initializer_list<value>& i) : x(nullptr)
{
if(i.size() == 2 and i.begin()->is_string())
if(i.size() == 2 and i.begin()->is_string() and i.begin()->get_key().empty())
{
key = i.begin()->get_string();
auto r = (i.begin() + 1)->x;
......
......@@ -2,6 +2,7 @@
#include <migraphx/operation.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <sstream>
#include <string>
#include "test.hpp"
......@@ -39,6 +40,15 @@ TEST_CASE(make_op_from_value2)
EXPECT(x == y);
}
TEST_CASE(make_rnn_op_from_value)
{
migraphx::op::rnn_direction dirct = migraphx::op::rnn_direction::reverse;
migraphx::operation x = migraphx::make_op(
"rnn_var_sl_shift_output", {{"output_name", "hidden_states"}, {"direction", dirct}});
migraphx::operation y = migraphx::op::rnn_var_sl_shift_output{"hidden_states", dirct};
EXPECT(x == y);
}
TEST_CASE(make_op_invalid_key)
{
EXPECT(test::throws([] { migraphx::make_op("convolution", {{"paddings", {1, 1}}}); }));
......
......@@ -72,7 +72,7 @@ TEST_CASE(value_construct_bool)
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_enum)
TEST_CASE(value_construct_enum1)
{
migraphx::value v = enum_type::a;
EXPECT(v.is_int64());
......@@ -81,6 +81,24 @@ TEST_CASE(value_construct_enum)
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_enum2)
{
migraphx::value v = enum_type::b;
EXPECT(v.is_int64());
EXPECT(v.get_int64() == static_cast<std::uint64_t>(enum_type::b));
EXPECT(bool{v.to<enum_type>() == enum_type::b});
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_enum3)
{
migraphx::value v = enum_type::c;
EXPECT(v.is_int64());
EXPECT(v.get_int64() == static_cast<std::uint64_t>(enum_type::c));
EXPECT(bool{v.to<enum_type>() == enum_type::c});
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_empty_object)
{
migraphx::value v = migraphx::value::object{};
......@@ -467,6 +485,36 @@ TEST_CASE(value_emplace_object)
EXPECT(v["three"].get_key() == "three");
}
TEST_CASE(value_construct_object_string_value)
{
migraphx::value v = {{"one", "onev"}, {"two", "twov"}};
EXPECT(v.is_object());
EXPECT(v.size() == 2);
EXPECT(not v.empty());
EXPECT(v.data() != nullptr);
EXPECT(v.at("one").is_string());
EXPECT(v.at("one").get_key() == "one");
EXPECT(v.at("one").get_string() == "onev");
EXPECT(v.at("two").is_string());
EXPECT(v.at("two").get_key() == "two");
EXPECT(v.at("two").get_string() == "twov");
}
TEST_CASE(value_construct_object_string_mixed_value)
{
migraphx::value v = {{"one", "onev"}, {"two", 2}};
EXPECT(v.is_object());
EXPECT(v.size() == 2);
EXPECT(not v.empty());
EXPECT(v.data() != nullptr);
EXPECT(v.at("one").is_string());
EXPECT(v.at("one").get_key() == "one");
EXPECT(v.at("one").get_string() == "onev");
EXPECT(v.at("two").is_int64());
EXPECT(v.at("two").get_key() == "two");
EXPECT(v.at("two").get_int64() == 2);
}
TEST_CASE(value_compare)
{
EXPECT(migraphx::value(1) == migraphx::value(1));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment