Unverified Commit faefeef9 authored by Charlie Lin, committed by GitHub

Merge branch 'develop' into dyn_shape_update

parents 97a40ac3 bf0a4713
......@@ -19,7 +19,7 @@ struct eliminate_allocation
std::string allocation_op{};
std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct module;
struct eliminate_common_subexpression
{
std::string name() const { return "eliminate_common_subexpression"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -18,7 +18,7 @@ struct eliminate_concat
{
concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -17,7 +17,7 @@ struct eliminate_contiguous
{
std::string op_name;
std::string name() const { return "eliminate_contiguous"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -18,7 +18,7 @@ struct module;
struct eliminate_identity
{
std::string name() const { return "eliminate_identity"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -9,7 +9,19 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v);
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template <class Value>
operation make_op(const std::string& name, const Value& v)
{
return make_op_from_value(name, v);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
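The new overload set is easiest to see at a call site. Below is a minimal sketch of both paths, assuming only the declarations above (the attribute names are the usual ones for these operators):

    #include <migraphx/make_op.hpp>
    #include <migraphx/value.hpp>

    void make_op_example()
    {
        // Brace-enclosed pairs select the initializer_list overload, so a
        // malformed element fails at compile time rather than at runtime.
        auto transpose = migraphx::make_op("transpose", {{"permutation", {1, 0}}});

        // A pre-built migraphx::value selects the template overload, which
        // forwards to make_op_from_value and validates the keys at runtime.
        migraphx::value v = {{"axis", 1}};
        auto concat = migraphx::make_op("concat", v);
    }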
......@@ -156,6 +156,19 @@ struct id_matcher
}
};
// Forward declare class and constructors
template <class M>
struct basic_matcher;
template <class M>
basic_matcher<M> make_basic_matcher(M m);
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher
template <class M>
struct basic_matcher
......@@ -167,8 +180,8 @@ struct basic_matcher
{
// Copy m because we can't capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result)
{
......@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct matcher_result
{
std::unordered_map<std::string, instruction_ref> instructions;
struct instruction_container
{
instruction_container() = default;
instruction_container(std::unordered_map<std::string, instruction_ref> x)
: ins_map(std::move(x))
{
}
instruction_ref operator[](const std::string& name) const
{
auto it = ins_map.find(name);
if(it == ins_map.end())
MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
return it->second;
}
auto find(const std::string& name) const { return ins_map.find(name); }
auto begin() const { return ins_map.cbegin(); }
auto end() const { return ins_map.cend(); }
bool has_instructions_in(const module& mod) const
{
return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
return mod.has_instruction(p.second);
});
}
private:
std::unordered_map<std::string, instruction_ref> ins_map;
};
instruction_container instructions;
instruction_ref result;
};
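Wrapping the map changes the failure mode: std::unordered_map::operator[] would silently default-construct a missing entry, while instruction_container throws on a name the matcher never bound. A short sketch, where `r` stands for a hypothetical matcher_result obtained from a successful match:

    #include <migraphx/matcher.hpp>

    void inspect(const migraphx::match::matcher_result& r)
    {
        // A name bound during matching (e.g. via .bind("x")) is returned directly.
        auto x = r.instructions["x"];
        (void)x;

        // An unbound name now throws instead of inserting a default entry:
        //   r.instructions["typo"]; // "Accessing name that wasn't bound in matcher: typo"

        // find() keeps a non-throwing lookup for optional bindings.
        if(r.instructions.find("y") != r.instructions.end())
        {
            auto y = r.instructions.find("y")->second;
            (void)y;
        }
    }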
......@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
result.result = ins;
result.instructions = ctx.instructions;
assert(result.instructions.has_instructions_in(mod));
}
else
{
......@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
});
}
inline auto var(std::string s)
{
return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s);
if(it == ctx.instructions.end())
return nullopt;
return it->second;
});
}
inline auto name(std::string s)
{
return make_basic_pred_matcher(
......@@ -696,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...);
}
template <class... Ms>
auto skip_broadcasts_converts(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
return skip_broadcasts(make_basic_pred_matcher([=](instruction_ref ins) {
return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "@literal")
return false;
auto l = ins->get_literal();
......
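The new var matcher ignores the instruction it is applied to and instead returns whatever instruction was previously bound to the given name, or nullopt (failing the match) if nothing was bound. That makes it possible to revisit an already-bound instruction later in the same pattern. A sketch, under the assumption that sub-matchers are evaluated left to right, which is how the all_of fold above proceeds:

    #include <migraphx/matcher.hpp>

    namespace match = migraphx::match;

    // Match an add whose first argument is a convolution; bind it as "x",
    // then come back to it with var("x") and require it to have one use.
    inline auto add_of_single_use_conv()
    {
        return match::name("add")(match::arg(0)(match::name("convolution").bind("x")),
                                  match::var("x")(match::used_once()));
    }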
......@@ -17,7 +17,7 @@ struct memory_coloring
std::string allocation_op{};
bool verify = false;
std::string name() const { return "memory coloring"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -15,7 +15,7 @@ struct module;
struct propagate_constant
{
std::string name() const { return "propagate_constant"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct module;
struct rewrite_batchnorm
{
std::string name() const { return "rewrite_batchnorm"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -15,7 +15,7 @@ struct module;
struct rewrite_pooling
{
std::string name() const { return "rewrite_pooling"; }
void apply(module& prog) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -19,22 +19,22 @@ struct module;
struct rewrite_rnn
{
std::string name() const { return "rewrite_rnn"; }
void apply(module& prog) const;
void apply(module& m) const;
private:
// for vanilla rnn operators
void apply_vanilla_rnn(module& prog, instruction_ref ins) const;
void apply_vanilla_rnn(module& m, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators
void apply_gru(module& prog, instruction_ref ins) const;
void apply_gru(module& m, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
......@@ -44,9 +44,9 @@ struct rewrite_rnn
std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators
void apply_lstm(module& prog, instruction_ref ins) const;
void apply_lstm(module& m, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
......@@ -55,24 +55,23 @@ struct rewrite_rnn
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& prog,
bool is_variable_seq_lens(const module& m, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
op::rnn_direction dirct) const;
void replace_last_cell_output(module& prog,
void replace_last_cell_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
instruction_ref last_cell_output,
op::rnn_direction dirct) const;
std::size_t
get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const;
std::size_t get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(module& prog,
instruction_ref pad_hidden_states(module& m,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const;
......
......@@ -19,7 +19,7 @@ struct schedule
schedule_model model{};
bool enable = true;
std::string name() const { return "schedule"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -15,7 +15,7 @@ struct module;
struct simplify_algebra
{
std::string name() const { return "simplify_algebra"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct module;
struct simplify_reshapes
{
std::string name() const { return "simplify_reshapes"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -5,20 +5,41 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); }
operation make_op(const std::string& name, const value& v)
template <class F>
operation make_op_generic(const std::string& name, F for_each)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object");
auto op = load_op(name);
// Merge values
value w = op.to_value();
for(auto&& x : v)
{
w.at(x.get_key()) = x.without_key();
}
for_each([&](const auto& key, const auto& x) {
if(not w.contains(key))
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
MIGRAPHX_THROW("No key '" + key + "' in " + name);
w.at(key) = x;
});
op.from_value(w);
return op;
}
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v)
{
return make_op_generic(name, [&](auto f) {
for(auto&& [key, x] : v)
f(key, x);
});
}
operation make_op_from_value(const std::string& name, const value& v)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object for make_op: " + name);
return make_op_generic(name, [&](auto f) {
for(auto&& x : v)
f(x.get_key(), x.without_key());
});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
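One behavioral consequence worth noting: because the supplied values are merged into the operator's default attributes, a key the operator does not define now throws rather than being silently dropped. Hypothetical calls:

    #include <migraphx/make_op.hpp>

    void key_validation_example()
    {
        // Fine: "axis" is an attribute of concat, so the default is overridden.
        auto ok = migraphx::make_op("concat", {{"axis", 1}});
        (void)ok;

        // Throws "No key 'axes' in concat" -- the misspelled key does not
        // exist in concat's to_value() defaults:
        //   auto bad = migraphx::make_op("concat", {{"axes", 1}});
    }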
......@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -9,6 +10,9 @@ namespace onnx {
struct parse_mean : op_parser<parse_mean>
{
const std::set<shape::type_t> float_types = {
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"Mean"}}; }
/// Calculates the element-wise mean of n>=1 input tensors
......@@ -24,17 +28,29 @@ struct parse_mean : op_parser<parse_mean>
auto divisor = info.add_literal(
migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
// TODO: Only divide when using floating-point
return std::accumulate(args.begin() + 1,
args.end(),
info.add_broadcastable_binary_op("div", args[0], divisor),
[&](auto mean, auto data_i) {
// Pre-divide each tensor element-wise by n to reduce risk of
// overflow during summation
auto div =
info.add_broadcastable_binary_op("div", data_i, divisor);
return info.add_broadcastable_binary_op("add", mean, div);
});
if(contains(float_types, args[0]->get_shape().type()))
{
return std::accumulate(args.begin() + 1,
args.end(),
info.add_broadcastable_binary_op("div", args[0], divisor),
[&](auto mean, auto data_i) {
// Pre-divide each tensor element-wise by n to reduce risk of
// overflow during summation
auto div =
info.add_broadcastable_binary_op("div", data_i, divisor);
return info.add_broadcastable_binary_op("add", mean, div);
});
}
else
{
// Compute sum before division for integral types
auto sum = std::accumulate(
args.begin() + 1, args.end(), args[0], [&](auto accum, auto data_i) {
return info.add_broadcastable_binary_op("add", accum, data_i);
});
return info.add_broadcastable_binary_op("div", sum, divisor);
}
}
};
......
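The branch order matters because of integer truncation: with three int32 inputs all equal to 1 (so num_data = 3), pre-dividing gives 1/3 + 1/3 + 1/3 = 0 + 0 + 0 = 0, while summing first gives (1 + 1 + 1) / 3 = 1, the correct mean. For floating-point inputs the original pre-divide order is kept, since dividing each operand by n first keeps the running sum small and reduces the risk of overflow during summation.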
......@@ -4,11 +4,11 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(module& p) const
void memory_coloring::apply(module& m) const
{
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{
memory_coloring_impl opt(&p, allocation_op, verify);
memory_coloring_impl opt(&m, allocation_op, verify);
opt.run();
}
}
......
......@@ -20,9 +20,9 @@ bool skip_propogate(instruction_ref ins)
return false;
}
void propagate_constant::apply(module& p) const
void propagate_constant::apply(module& m) const
{
for(auto i : iterator_for(p))
for(auto i : iterator_for(m))
{
if(i->name() != "@literal")
continue;
......@@ -42,8 +42,8 @@ void propagate_constant::apply(module& p) const
if(not r.empty())
{
assert(r.get_shape() == child->get_shape());
auto l = p.add_literal(r.get_shape(), r.data());
self(p.replace_instruction(child, l));
auto l = m.add_literal(r.get_shape(), r.data());
self(m.replace_instruction(child, l));
}
}
})(i);
......
......@@ -273,6 +273,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("op"),
py::arg("args"),
py::arg("mod_args") = std::vector<migraphx::module*>{})
.def(
"add_literal",
[](migraphx::module& mm, py::buffer data) {
py::buffer_info info = data.request();
auto literal_shape = to_shape(info);
return mm.add_literal(literal_shape, reinterpret_cast<char*>(info.ptr));
},
py::arg("data"))
.def(
"add_parameter",
[](migraphx::module& mm, const std::string& name, const migraphx::shape shape) {
......