Commit 00d5d880 authored by Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 00d90ca8 f60c3815
#ifndef MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP

#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

struct inline_module
{
    std::string name() const { return "inline_module"; }
    void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP

#include <string>
#include <vector>
#include <array>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

/**
 * Insert an explicit pad instruction when the padding attribute is asymmetric.
 */
struct insert_pad
{
    std::string name() const { return "insert_pad"; }
    void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
@@ -164,6 +164,18 @@ struct hash<migraphx::instruction_ref>
     }
 };
 
+template <>
+struct equal_to<migraphx::instruction_ref>
+{
+    using argument_type = migraphx::instruction_ref;
+    using result_type   = bool;
+    result_type operator()(const migraphx::instruction_ref& x,
+                           const migraphx::instruction_ref& y) const noexcept
+    {
+        return &*x == &*y;
+    }
+};
+
 } // namespace std
 #endif
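The two specializations above let instruction_ref serve as a key in hashed containers. A minimal sketch of the kind of use they enable (the visit_once helper is hypothetical, not part of this commit):

#include <unordered_set>
#include <migraphx/instruction_ref.hpp>

// Returns true only the first time an instruction is seen; two iterators
// referring to the same instruction hash and compare equal, so they dedupe.
inline bool visit_once(std::unordered_set<migraphx::instruction_ref>& seen,
                       migraphx::instruction_ref ins)
{
    return seen.insert(ins).second;
}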
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP

#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template <class Iterator, class EndIterator>
auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(!it._M_dereferenceable())
{
    return !it._M_dereferenceable();
}

template <class Iterator, class EndIterator>
auto is_end(rank<1>, Iterator it, EndIterator last)
{
    return it == last;
}

template <class Iterator, class EndIterator>
bool is_end(Iterator it, EndIterator last)
{
    return is_end(rank<2>{}, it, last);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif // MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
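is_end dispatches on rank: rank<2> derives from rank<1>, so the overload that SFINAE-probes libstdc++'s debug hook _M_dereferenceable wins when it compiles, and the plain it == last comparison is the fallback. A self-contained sketch of the same idiom (illustrative names; MIGraphX's real rank lives in migraphx/rank.hpp):

#include <iostream>
#include <string>

template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};

// Preferred overload: viable only when T has a size() member.
template <class T>
auto describe(rank<1>, const T& x) -> decltype(x.size(), std::string())
{
    return "sized: " + std::to_string(x.size());
}

// Fallback overload, chosen via derived-to-base conversion of the tag.
template <class T>
std::string describe(rank<0>, const T&)
{
    return "opaque";
}

int main()
{
    std::string s = "abc";
    std::cout << describe(rank<1>{}, s) << '\n';  // prints "sized: 3"
    std::cout << describe(rank<1>{}, 42) << '\n'; // prints "opaque"
}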
@@ -5,7 +5,7 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/module.hpp>
-#include <migraphx/program.hpp>
+#include <migraphx/optional.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/type_name.hpp>
 #include <migraphx/config.hpp>
@@ -19,24 +19,51 @@ namespace match {
 struct matcher_context
 {
-    matcher_context(instruction_ref i) : last(i) {}
+    matcher_context(module& m) : mod(&m) {}
     std::unordered_map<std::string, instruction_ref> instructions;
-    instruction_ref not_found() const { return last; }
 
     template <class M>
     bool matched(M m, instruction_ref ins)
     {
-        return m.match(*this, ins) != this->not_found();
+        return has_value(m.match(*this, ins));
     }
 
     template <class M>
-    auto lazy_match(M m, instruction_ref ins)
+    bool matched(M m, optional<instruction_ref> ins)
+    {
+        if(ins)
+            return has_value(m.match(*this, *ins));
+        return false;
+    }
+
+    template <class M, class I>
+    auto lazy_match(M m, I ins)
     {
         return [=] { return this->matched(m, ins); };
     }
 
+    bool has_instruction(instruction_ref ins) const
+    {
+        if(mod == nullptr)
+            return true;
+        return mod->has_instruction(ins);
+    }
+
+    bool has_instruction(optional<instruction_ref> ins) const
+    {
+        if(ins)
+            return this->has_instruction(*ins);
+        return false;
+    }
+
+    bool is_last(instruction_ref ins) const
+    {
+        assert(mod->begin() != mod->end());
+        assert(this->has_instruction(ins));
+        return ins == std::prev(mod->end());
+    }
+
     private:
-    instruction_ref last;
+    module* mod = nullptr;
 };
 
 /// Convert a predicate function into a matcher
@@ -45,12 +72,11 @@ struct predicate_matcher
 {
     P p;
 
-    instruction_ref match(const matcher_context& ctx, instruction_ref ins) const
+    optional<instruction_ref> match(const matcher_context&, instruction_ref ins) const
     {
-        assert(ins != ctx.not_found());
         if(p(ins))
-            return ins;
-        return ctx.not_found();
+            return optional<instruction_ref>{ins};
+        return nullopt;
     }
 };
@@ -60,11 +86,7 @@ struct function_matcher
 {
     F f;
 
-    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
-    {
-        assert(ins != ctx.not_found());
-        return f(ctx, ins);
-    }
+    auto match(matcher_context& ctx, instruction_ref ins) const { return f(ctx, ins); }
 };
 
 /// Convert a function into a matcher
@@ -79,12 +101,17 @@ template <class M>
 auto bind_match(M m, std::string name)
 {
     return make_function_matcher(
-        [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
-            auto result = m.match(ctx, ins);
-            if(result != ctx.not_found())
-                ctx.instructions[name] = ins;
-            return result;
-        });
+        [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins)
+            ->optional<instruction_ref> {
+            auto result = m.match(ctx, ins);
+            if(result)
+            {
+                if(not ctx.has_instruction(ins))
+                    return nullopt;
+                ctx.instructions[name] = ins;
+            }
+            return result;
+        });
 }
 
 /// Convert a matcher to a bindable matcher
@@ -95,10 +122,7 @@ struct bindable_matcher
     auto bind(std::string name) const { return bind_match(m, std::move(name)); }
 
-    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
-    {
-        return m.match(ctx, ins);
-    }
+    auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
 };
 
 /// Create a bindable matcher
@@ -126,7 +150,10 @@ using bool_list = std::initializer_list<bool>;
 struct id_matcher
 {
-    instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
+    auto match(matcher_context&, instruction_ref ins) const
+    {
+        return optional<instruction_ref>{ins};
+    }
 };
 
 /// The basic matcher provides the all_of composability of the matcher
@@ -140,26 +167,23 @@ struct basic_matcher
     {
         // Copy m because we cant capture `this` by value
         auto mm = m;
-        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
+        return make_bf_matcher([=](matcher_context& ctx,
+                                   instruction_ref ins) -> optional<instruction_ref> {
             auto result = mm.match(ctx, ins);
-            if(result != ctx.not_found())
+            if(result)
             {
-                bool matches = fold([&](auto x, auto y) {
-                    return x and y.match(ctx, result) != ctx.not_found();
-                })(true, ms...);
+                bool matches =
+                    fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
                 if(matches)
                     return result;
             }
-            return ctx.not_found();
+            return nullopt;
         });
     }
 
     auto bind(std::string name) const { return bind_match(m, std::move(name)); }
 
-    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
-    {
-        return m.match(ctx, ins);
-    }
+    auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
 };
 
 /// Create a basic matcher from a matcher
@@ -185,7 +209,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
 /// Create a typed-erased matcher
 using any_matcher_base = basic_matcher<
-    function_matcher<std::function<instruction_ref(matcher_context&, instruction_ref)>>>;
+    function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
 struct any_matcher : any_matcher_base
 {
     template <class M>
@@ -198,10 +222,10 @@ struct any_matcher : any_matcher_base
 #define MIGRAPHX_BASIC_MATCHER(name, ...)                                     \
     struct name##_m                                                           \
     {                                                                         \
-        instruction_ref match(__VA_ARGS__) const;                             \
+        optional<instruction_ref> match(__VA_ARGS__) const;                   \
     };                                                                        \
     const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
-    inline instruction_ref name##_m::match(__VA_ARGS__) const
+    inline optional<instruction_ref> name##_m::match(__VA_ARGS__) const
 
 /// This macro takes care of the boilerplate for defining a predicate matcher
 #define MIGRAPHX_PRED_MATCHER(name, ...)                                      \
@@ -221,21 +245,29 @@ struct matcher_result
 /// Match a single instruction
 template <class M>
-matcher_result match_instruction(module& p, instruction_ref ins, M&& m)
+matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
 {
-    assert(ins != p.end());
+    assert(ins != mod.end());
+    assert(mod.has_instruction(ins));
+    matcher_context ctx{mod};
     matcher_result result;
-    matcher_context ctx{p.end()};
-    result.result       = m.match(ctx, ins);
-    result.instructions = ctx.instructions;
+    if(m.match(ctx, ins))
+    {
+        result.result       = ins;
+        result.instructions = ctx.instructions;
+    }
+    else
+    {
+        result.result = mod.end();
+    }
     return result;
 }
 
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
 
-/// Find matches for an instruction in the program
+/// Find matches for an instruction in the module
 template <class... Ms>
-void find_matches(module& p, instruction_ref ins, Ms&&... ms)
+void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
 {
 #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
     const
@@ -246,27 +278,27 @@ void find_matches(module& p, instruction_ref ins, Ms&&... ms)
         [&](auto&& m) {
             if(match)
                 return;
-            auto r = match_instruction(p, ins, m.matcher());
-            if(r.result == p.end())
+            auto r = match_instruction(mod, ins, m.matcher());
+            if(r.result == mod.end())
                 return;
             if(trace)
             {
                 std::cout << "Matched by " << get_type_name(m) << std::endl;
-                p.debug_print(ins);
+                mod.debug_print(ins);
             }
-            m.apply(p, r);
+            m.apply(mod, r);
             match = true;
         },
         ms...);
 }
 
-/// Find matches in a program
+/// Find matches in a module
 template <class... Ms>
-void find_matches(module& p, Ms&&... ms)
+void find_matches(module& mod, Ms&&... ms)
 {
-    for(auto ins : iterator_for(p))
+    for(auto ins : iterator_for(mod))
     {
-        find_matches(p, ins, ms...);
+        find_matches(mod, ins, ms...);
     }
 }
@@ -339,12 +371,13 @@ struct match_fold_f
     template <class... Ts>
     auto operator()(Ts... ms) const
     {
-        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-            bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
-            if(matches == Matches)
-                return ins;
-            return ctx.not_found();
-        });
+        return make_bf_matcher(
+            [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
+                bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
+                if(matches == Matches)
+                    return {ins};
+                return nullopt;
+            });
     }
 
     template <class Selector>
@@ -353,17 +386,18 @@ struct match_fold_f
         return [=](auto... ms) {
             // Workaround ICE on gcc by packing matchers into an object
            auto mpack = pack(ms...);
-            return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
-                Op op;
-                bool matches = Start;
-                select(start, [&](auto ins) {
-                    auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
-                    matches = op(always(matches), fm)();
-                });
-                if(matches == Matches)
-                    return start;
-                return ctx.not_found();
-            });
+            return make_bf_matcher(
+                [=](matcher_context& ctx, instruction_ref start) -> optional<instruction_ref> {
+                    Op op;
+                    bool matches = Start;
+                    select(start, [&](auto ins) {
+                        auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
+                        matches = op(always(matches), fm)();
+                    });
+                    if(matches == Matches)
+                        return {start};
+                    return nullopt;
+                });
         };
     }
 };
@@ -420,64 +454,29 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
         ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
 }
 
-MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins)
+MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
 {
     if(ins->outputs().size() == 1)
-        return ins->outputs().front();
-    return ctx.not_found();
+        return {ins->outputs().front()};
+    return nullopt;
 }
 
 MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
 {
     if(ins->outputs().size() == 1)
-        return ins;
-    if(ins->outputs().empty() and std::next(ins) == ctx.not_found())
-        return ins;
-    return ctx.not_found();
+        return {ins};
+    if(ins->outputs().empty() and ctx.is_last(ins))
+        return {ins};
+    return nullopt;
 }
 
-inline auto used_once_recursive(std::size_t depth)
-{
-    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
-        // Used once
-        if(start->outputs().size() == 1)
-            return start;
-        // Unused
-        if(start->outputs().empty())
-        {
-            if(std::next(start) == ctx.not_found())
-                return start;
-            else
-                return ctx.not_found();
-        }
-        // Check for dead instructions
-        auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
-            if(n == 0)
-                return false;
-            if(ins->get_shape().elements() == 0)
-                return false;
-            if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
-                return true;
-            return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
-                return self(i, n - 1);
-            });
-        });
-        auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
-            return is_dead(i, depth);
-        });
-        if(dead + 1 == start->outputs().size())
-            return start;
-        return ctx.not_found();
-    });
-}
-
 MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
 
 MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
 {
-    if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
-        return ins;
-    return ctx.not_found();
+    if(ins->outputs().empty() and not ctx.is_last(ins))
+        return {ins};
+    return nullopt;
 }
 
 template <class... Ms>
@@ -485,14 +484,15 @@ auto skip(Ms... ms)
 {
     auto m = any_of(ms...);
     return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
-        return fix<instruction_ref>([&](auto self, auto ins) {
-            if(ins->inputs().size() == 1 and ctx.matched(m, ins))
-            {
-                auto next = ins->inputs().front();
-                return self(next);
-            }
-            return ins;
-        })(start);
+        return fix<optional<instruction_ref>>(
+            [&](auto self, auto ins) -> optional<instruction_ref> {
+                if(ins->inputs().size() == 1 and ctx.matched(m, ins))
+                {
+                    auto next = ins->inputs().front();
+                    return self(next);
+                }
+                return ins;
+            })(start);
     });
 }
@@ -501,20 +501,21 @@ auto skip_output(Ms... ms)
 {
     auto m = any_of(ms...);
     return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
-        return fix<instruction_ref>([&](auto self, auto ins) {
-            if(ins->outputs().size() == 1)
-            {
-                auto next = ins->outputs().front();
-                if(ctx.matched(m, next))
-                {
-                    auto skipped_next = self(next);
-                    if(skipped_next != ctx.not_found())
-                        return skipped_next;
-                }
-                return next;
-            }
-            return ctx.not_found();
-        })(start);
+        return fix<optional<instruction_ref>>(
+            [&](auto self, auto ins) -> optional<instruction_ref> {
+                if(ins->outputs().size() == 1)
+                {
+                    auto next = ins->outputs().front();
+                    if(ctx.matched(m, next))
+                    {
+                        auto skipped_next = self(next);
+                        if(skipped_next)
+                            return skipped_next;
+                    }
+                    return next;
+                }
+                return nullopt;
+            })(start);
     });
 }
@@ -550,11 +551,12 @@ inline auto nargs(std::size_t n)
 inline auto arg(std::size_t i)
 {
-    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) {
-        if(i < ins->inputs().size())
-            return ins->inputs()[i];
-        return ctx.not_found();
-    });
+    return make_basic_fun_matcher(
+        [=](const matcher_context&, instruction_ref ins) -> optional<instruction_ref> {
+            if(i < ins->inputs().size())
+                return ins->inputs()[i];
+            return nullopt;
+        });
 }
 
 // Workaround for bugs in clang
@@ -616,52 +618,55 @@ std::size_t tree_leafs_impl(matcher_context& ctx,
 template <class M, class... Ms>
 auto tree(M main_op, Ms... ms)
 {
-    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        // Flatten leaf nodes
-        std::array<instruction_ref, sizeof...(Ms)> leafs;
-        std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
-        if(idx != leafs.size())
-            return ctx.not_found();
-        // Use explicit captures to workaround ICE on gcc
-        bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
-            return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
-        });
-        if(not found)
-            return ctx.not_found();
-        return ins;
-    });
+    return make_basic_fun_matcher(
+        [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
+            // Flatten leaf nodes
+            std::array<instruction_ref, sizeof...(Ms)> leafs;
+            std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
+            if(idx != leafs.size())
+                return nullopt;
+            // Use explicit captures to workaround ICE on gcc
+            bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
+                return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
            });
+            if(not found)
+                return nullopt;
+            return ins;
+        });
 }
 
 template <class M, class... Ms>
 auto unordered_tree(M main_op, Ms... ms)
 {
-    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        // Flatten leaf nodes
-        std::array<instruction_ref, sizeof...(Ms)> leafs;
-        std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
-        if(idx != leafs.size())
-            return ctx.not_found();
-        // Use explicit captures to workaround ICE on gcc
-        bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
-            return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
-                return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
-            })(ms...)();
-        });
-        if(not found)
-            return ctx.not_found();
-        return ins;
-    });
+    return make_basic_fun_matcher(
+        [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
+            // Flatten leaf nodes
+            std::array<instruction_ref, sizeof...(Ms)> leafs;
+            std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
+            if(idx != leafs.size())
+                return nullopt;
+            // Use explicit captures to workaround ICE on gcc
+            bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
+                return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
+                    return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
+                })(ms...)();
            });
+            if(not found)
+                return nullopt;
+            return ins;
+        });
 }
 
 template <class M>
 auto same_shape(M m)
 {
-    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        auto i = m.match(ctx, ins);
-        if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
-            return ins;
-        return ctx.not_found();
-    });
+    return make_basic_fun_matcher(
+        [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
+            auto i = m.match(ctx, ins);
+            if(i and (*i)->get_shape() == ins->get_shape())
+                return ins;
+            return nullopt;
+        });
 }
 
 template <class... Ms>
...
@@ -9,6 +9,8 @@
 #include <migraphx/literal.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/value.hpp>
+#include <migraphx/op/normalize_attribute.hpp>
 #include <cmath>
 #include <utility>
 
@@ -39,25 +41,30 @@ struct convolution
     void check_attribute_size() const
     {
-        if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
+        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
+               stride.size() == dilation.size()))
         {
             MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
         }
     }
 
-    shape compute_shape(std::vector<shape> inputs) const
+    value attributes() const { return {{"normalize_padding", "padding"}}; }
+
+    shape normalize_compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
         check_attribute_size();
         // dim num of input and attribute should match
-        if(inputs[0].lens().size() != padding.size() + 2)
+        auto input_size   = inputs[0].lens().size();
+        auto padding_size = padding.size();
+        if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
         {
             MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
         }
 
         const shape& input   = inputs.at(0);
         const shape& weights = inputs.at(1);
-        size_t kdims         = input.lens().size() - 2;
+        size_t kdims         = input_size - 2;
         if(kdims != this->kdims())
         {
             MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
@@ -70,10 +77,13 @@ struct convolution
         for(size_t i = 0; i < kdims; i++)
         {
+            auto padding_factor = 2 * padding[i];
+            if(padding_size == 2 * kdims)
+                padding_factor = padding[i] + padding[i + kdims];
             output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
                 1,
                 (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
-                 2 * padding[i]) /
+                 padding_factor) /
                         stride[i] +
                     1)));
         }
@@ -84,7 +94,7 @@ struct convolution
     size_t kdims() const
     {
         check_attribute_size();
-        return padding.size();
+        return stride.size();
     }
 };
...
@@ -39,7 +39,8 @@ struct deconvolution
     void check_attribute_size() const
     {
-        if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
+        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
+               stride.size() == dilation.size()))
         {
             MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes");
         }
@@ -72,7 +73,7 @@ struct deconvolution
     size_t kdims() const
     {
         check_attribute_size();
-        return padding.size();
+        return stride.size();
     }
 };
...
#ifndef MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#define MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP

#include "migraphx/errors.hpp"
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct get_tuple_elem
{
    std::size_t index = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.index, "index"));
    }

    std::string name() const { return "get_tuple_elem"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).tuple_type();
        const auto& sub_shapes = inputs.at(0).sub_shapes();
        if(index >= sub_shapes.size())
        {
            MIGRAPHX_THROW("GET_TUPLE_ELEM: index " + std::to_string(index) +
                           " is out of range " + std::to_string(sub_shapes.size()));
        }
        return sub_shapes.at(index);
    }

    argument compute(const shape&, std::vector<argument> args) const
    {
        assert(args.size() == 1);
        auto vec_args = args.at(0).get_sub_objects();
        assert(index < vec_args.size());
        return vec_args.at(index);
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
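A short sketch of how get_tuple_elem pairs with a tuple-shaped instruction (such as the if_op result changed below): each consumer selects one sub-shape/sub-argument by index. The unpack_tuple helper is hypothetical, not part of this commit:

#include <cstddef>
#include <iterator>
#include <vector>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>

// Hypothetical helper: given an instruction whose shape is a tuple, insert one
// get_tuple_elem per element right after it and return the element refs.
inline std::vector<migraphx::instruction_ref> unpack_tuple(migraphx::module& m,
                                                           migraphx::instruction_ref ins)
{
    std::vector<migraphx::instruction_ref> elems;
    for(std::size_t i = 0; i < ins->get_shape().sub_shapes().size(); i++)
        elems.push_back(m.insert_instruction(
            std::next(ins), migraphx::make_op("get_tuple_elem", {{"index", i}}), {ins}));
    return elems;
}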
@@ -35,14 +35,14 @@ struct if_op
             MIGRAPHX_THROW("IF: output shapes of submodules must be the same.");
         }
 
-        return out_shapes0.front();
+        return shape(out_shapes0);
     }
 
-    argument compute(
+    argument compute(const shape&,
         const std::vector<argument>& args,
         const std::vector<module_ref>& mods,
         const std::function<std::vector<argument>(
-            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)>& run) const
+            module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
     {
         auto cond      = args.front().at<bool>();
         module_ref mod = cond ? mods[0] : mods[1];
@@ -63,7 +63,7 @@ struct if_op
             [](auto&& name, auto&& arg) { return std::make_pair(name, arg); });
 
         auto results = run(mod, params);
-        return results[0];
+        return argument{results};
     }
 };
...
@@ -31,7 +31,9 @@ struct im2col
     std::string name() const { return "im2col"; }
 
-    shape compute_shape(std::vector<shape> inputs) const
+    value attributes() const { return {{"normalize_padding", "padding"}}; }
+
+    shape normalize_compute_shape(std::vector<shape> inputs) const
     {
         auto input   = inputs[0];
         auto weights = inputs[1];
@@ -42,17 +44,24 @@ struct im2col
         check_shapes{inputs, *this}.has(2);
         if(batch_size != 1)
             MIGRAPHX_THROW("im2col only support batch_size 1");
 
+        auto padding_h = 2 * padding[0];
+        auto padding_w = 2 * padding[1];
+        if(padding.size() == 2 * stride.size())
+        {
+            padding_h = padding[0] + padding[2];
+            padding_w = padding[1] + padding[3];
+        }
         auto output_height = std::size_t(std::max<std::ptrdiff_t>(
             1,
-            (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
-                    stride[0] +
+            (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + padding_h) / stride[0] +
                 1));
         auto output_width = std::size_t(std::max<std::ptrdiff_t>(
             1,
-            (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) /
-                    stride[1] +
+            (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + padding_w) / stride[1] +
                 1));
 
         auto channels_col = kernel_height * kernel_width * input_channels;
         return {input.type(), {output_height * output_width, channels_col}};
     }
 };
...
@@ -15,6 +15,7 @@ namespace op {
 // 3.1) include_min(default)/exclude_min
 // 4) clip_max(default)/not_clip_max
 // 4.1) exclude_max(default)/include_max
+// 5) normalize padding
 enum class normalize_attribute
 {
     use_len,
@@ -22,7 +23,8 @@ enum class normalize_attribute
     clip_max,
     clip_min,
     include_max,
-    include_min
+    include_min,
+    normalize_padding
 };
 
 } // namespace op
...
@@ -8,6 +8,7 @@
 #include <migraphx/streamutils.hpp>
 #include <migraphx/functional.hpp>
 #include <migraphx/literal.hpp>
+#include <migraphx/value.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/int_divide.hpp>
 #include <migraphx/config.hpp>
@@ -40,29 +41,39 @@ struct pooling
     void check_attribute_size() const
     {
-        if(not(padding.size() == stride.size() and padding.size() == lengths.size()))
+        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
+               stride.size() == lengths.size()))
         {
             MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
         }
     }
 
-    shape compute_shape(std::vector<shape> inputs) const
+    value attributes() const { return {{"normalize_padding", "padding"}}; }
+
+    shape normalize_compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(1);
         const shape& input = inputs.at(0);
-        auto input_lens    = input.lens();
-        size_t kdims       = input_lens.size() - 2;
-        if(kdims != this->kdims())
+
+        auto input_lens   = input.lens();
+        size_t kdims      = input_lens.size() - 2;
+        auto input_size   = inputs[0].lens().size();
+        auto padding_size = padding.size();
+        if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
         {
-            MIGRAPHX_THROW("pooling: input k-dims does not match attribute size");
+            MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
         }
 
         std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
 
         for(size_t i = 0; i < kdims; i++)
         {
-            std::ptrdiff_t dim_size = input_lens[i + 2] + 2 * padding[i] - lengths[i];
+            std::ptrdiff_t dim_size;
+            auto padding_factor = 2 * padding[i];
+            if(padding_size == 2 * kdims)
+                padding_factor = padding[i] + padding[i + kdims];
+            dim_size = input_lens[i + 2] + padding_factor - lengths[i];
             assert(dim_size >= 0);
             std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i])
                                           : floor_divide<std::ptrdiff_t>(dim_size, stride[i]);
@@ -75,7 +86,7 @@ struct pooling
     size_t kdims() const
     {
         check_attribute_size();
-        return padding.size();
+        return stride.size();
     }
 };
...
@@ -36,19 +36,23 @@ struct quant_convolution
                       f(self.group, "group"));
     }
 
-    value attributes() const { return {{"general_data_type", "convolution"}}; }
+    value attributes() const
+    {
+        return {{"general_data_type", "convolution"}, {"normalize_padding", "padding"}};
+    }
 
     std::string name() const { return "quant_convolution"; }
 
     void check_attribute_size() const
     {
-        if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
+        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
+               stride.size() == dilation.size()))
         {
-            MIGRAPHX_THROW("quant_convolution: inconsistent attribute sizes");
+            MIGRAPHX_THROW("QUANT_CONVOLUTION: inconsistent attribute sizes");
         }
     }
 
-    shape compute_shape(std::vector<shape> inputs) const
+    shape normalize_compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
         check_attribute_size();
@@ -70,13 +74,16 @@ struct quant_convolution
         t = shape::int32_type;
 
         std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
+        auto padding_size = padding.size();
         for(size_t i = 0; i < kdims; i++)
         {
+            auto padding_factor = 2 * padding[i];
+            if(padding_size == 2 * kdims)
+                padding_factor = padding[i] + padding[i + kdims];
             output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
                 1,
                 (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
-                 2 * padding[i]) /
+                 padding_factor) /
                         stride[i] +
                     1)));
         }
@@ -87,7 +94,7 @@ struct quant_convolution
     size_t kdims() const
     {
         check_attribute_size();
-        return padding.size();
+        return stride.size();
     }
 };
...
#ifndef MIGRAPHX_GUARD_OPERATORS_REVERSE_HPP
#define MIGRAPHX_GUARD_OPERATORS_REVERSE_HPP

#include <algorithm>
#include <vector>
#include <cmath>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/value.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct reverse
{
    std::vector<int64_t> axes;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"));
    }

    std::string name() const { return "reverse"; }

    value attributes() const
    {
        value normalize;
        normalize["axes"] = value::array{normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

    shape normalize_compute_shape(std::vector<shape> inputs) const
    {
        return inputs[0].with_lens(inputs[0].lens());
    }

    argument compute(const shape& s, std::vector<argument> args) const
    {
        argument result{s};
        auto lens = s.lens();
        visit_all(result, args.front())([&](auto output, auto input) {
            shape_for_each(s, [&](const auto& out_idx) {
                auto in_idx = out_idx;
                for(const auto& axis : axes)
                {
                    in_idx[axis] = lens[axis] - 1 - out_idx[axis];
                }
                output[s.index(out_idx)] = input[s.index(in_idx)];
            });
        });
        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
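A standalone sketch (not MIGraphX code) of the index arithmetic in reverse::compute above: reversing axis 1 of a 2x3 row-major array.

#include <array>
#include <cstddef>

int main()
{
    std::array<int, 6> in{{1, 2, 3, 4, 5, 6}}; // lens = {2, 3}
    std::array<int, 6> out{};
    const std::size_t lens[2] = {2, 3};
    for(std::size_t r = 0; r < lens[0]; r++)
        for(std::size_t c = 0; c < lens[1]; c++)
            // in_idx[axis] = lens[axis] - 1 - out_idx[axis], with axis == 1
            out[r * lens[1] + c] = in[r * lens[1] + (lens[1] - 1 - c)];
    return out[0] == 3 ? 0 : 1; // first row {1, 2, 3} becomes {3, 2, 1}
}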
@@ -7,6 +7,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/functional.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/op/normalize_attribute.hpp>
 #include <cmath>
 #include <utility>
 
@@ -25,8 +26,15 @@ struct step
         return pack(f(self.axes, "axes"), f(self.steps, "steps"));
     }
 
+    value attributes() const
+    {
+        value normalize;
+        normalize["axes"] = value::array{normalize_attribute::include_min};
+        return {{"normalize_axes", normalize}};
+    }
+
     std::string name() const { return "step"; }
 
-    shape compute_shape(std::vector<shape> inputs) const
+    shape normalize_compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(1);
         auto input = inputs.at(0);
...
@@ -178,7 +178,7 @@ shape normalize_compute_shape_op(const T& x,
 }
 
 template <class T>
-auto compute_op(rank<2>,
+auto compute_op(rank<1>,
                 const T& x,
                 context& ctx,
                 const shape& output_shape,
@@ -188,14 +188,6 @@ auto compute_op(rank<2>,
     return x.compute(auto_any_cast(ctx), output_shape, input);
 }
 
-template <class T>
-auto compute_op(
-    rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-    -> decltype(x.compute(output_shape, input))
-{
-    return x.compute(output_shape, input);
-}
-
 template <class T>
 argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
 {
@@ -207,50 +199,106 @@ template <class T>
 argument
 compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
 {
-    return compute_op(rank<2>{}, x, ctx, output_shape, input);
+    return compute_op(rank<1>{}, x, ctx, output_shape, input);
 }
 
 template <class T>
-auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
+auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
     -> decltype(x.compute(output_shape, input))
 {
     return x.compute(output_shape, input);
 }
 
 template <class T>
-auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-    -> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
-{
-    std::string name = x.name();
-    MIGRAPHX_THROW("Not computable without a context: " + name);
-}
-
-template <class T>
 argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
 {
     std::string name = x.name();
     MIGRAPHX_THROW("Not computable: " + name);
 }
 
 template <class T>
 argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
 {
-    return compute_op(rank<2>{}, x, output_shape, input);
+    return compute_op(rank<1>{}, x, output_shape, input);
 }
 
 template <class T, class F>
 auto compute_op(rank<1>,
                 const T& x,
+                const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>& module_args,
-                F f) -> decltype(x.compute(inputs, module_args, f))
+                F f) -> decltype(x.compute(output, inputs, module_args, f))
 {
-    return x.compute(inputs, module_args, f);
+    return x.compute(output, inputs, module_args, f);
 }
 
 template <class T, class F>
-argument
-compute_op(rank<0>, const T& x, const std::vector<argument>&, const std::vector<module_ref>&, F)
+argument compute_op(rank<0>,
+                    const T& x,
+                    const shape&,
+                    const std::vector<argument>&,
+                    const std::vector<module_ref>&,
+                    F)
 {
     std::string name = x.name();
     MIGRAPHX_THROW("Not computable: " + name);
 }
+
+template <class T, class F>
+argument compute_op(const T& x,
+                    const shape& output,
+                    const std::vector<argument>& inputs,
+                    const std::vector<module_ref>& module_args,
+                    F f)
+{
+    return compute_op(rank<1>{}, x, output, inputs, module_args, f);
+}
+
+template <class T, class F>
+auto compute_op(rank<3>,
+                const T& x,
+                context&,
+                const shape& output,
+                const std::vector<argument>& inputs,
+                const std::vector<module_ref>& module_args,
+                F f) -> decltype(x.compute(output, inputs, module_args, f))
+{
+    return x.compute(output, inputs, module_args, f);
+}
+
+template <class T, class F>
+auto compute_op(rank<2>,
+                const T& x,
+                context&,
+                const shape& output,
+                const std::vector<argument>& inputs,
+                const std::vector<module_ref>&,
+                F) -> decltype(x.compute(output, inputs))
+{
+    return x.compute(output, inputs);
+}
+
+template <class T, class F>
+auto compute_op(rank<1>,
+                const T& x,
+                context& ctx,
+                const shape& output,
+                const std::vector<argument>& inputs,
+                const std::vector<module_ref>&,
+                F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
+{
+    return x.compute(auto_any_cast(ctx), output, inputs);
+}
+
+template <class T, class F>
+argument compute_op(rank<0>,
+                    const T& x,
+                    context&,
+                    const shape&,
+                    const std::vector<argument>&,
+                    const std::vector<module_ref>&,
+                    F)
+{
+    std::string name = x.name();
+    MIGRAPHX_THROW("Not computable: " + name);
+}
@@ -258,11 +306,13 @@ argument
 
 template <class T, class F>
 argument compute_op(const T& x,
+                    context& ctx,
+                    const shape& output,
                     const std::vector<argument>& inputs,
                     const std::vector<module_ref>& module_args,
                     F f)
 {
-    return compute_op(rank<1>{}, x, inputs, module_args, f);
+    return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
 }
 
 template <class T>
@@ -409,9 +459,12 @@ bool is_borrowed_op(const T&)
 * shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
 * mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
 * input) const; argument compute(const shape& output,const std::vector<argument>& input)
-* const; argument compute(const std::vector<argument>& input,const std::vector<module_ref>&
-* module_args,std::function<std::vector<argument>(module_ref& mdl, const
-* std::unordered_map<std::string, argument>& inputs)> run) const; value to_value() const; void
+* const; argument compute(const shape& output,const std::vector<argument>& input,const
+* std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const
+* std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const
+* shape& output,const std::vector<argument>& input,const std::vector<module_ref>&
+* module_args,std::function<std::vector<argument>(module_ref&, const
+* std::unordered_map<std::string, argument>&)> run) const; value to_value() const; void
 * from_value(const value& v) ; value attributes() const; friend std::ostream &
 * operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation &
 * x,const operation & y) ;
@@ -555,14 +608,27 @@ struct operation
         return (*this).private_detail_te_get_handle().compute(output, input);
     }
 
-    argument compute(
-        const std::vector<argument>& input,
-        const std::vector<module_ref>& module_args,
-        std::function<std::vector<argument>(
-            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run) const
+    argument compute(const shape& output,
+                     const std::vector<argument>& input,
+                     const std::vector<module_ref>& module_args,
+                     std::function<std::vector<argument>(
+                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const
     {
         assert((*this).private_detail_te_handle_mem_var);
-        return (*this).private_detail_te_get_handle().compute(input, module_args, std::move(run));
+        return (*this).private_detail_te_get_handle().compute(
+            output, input, module_args, std::move(run));
+    }
+
+    argument compute(context& ctx,
+                     const shape& output,
+                     const std::vector<argument>& input,
+                     const std::vector<module_ref>& module_args,
+                     std::function<std::vector<argument>(
+                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        return (*this).private_detail_te_get_handle().compute(
+            ctx, output, input, module_args, std::move(run));
     }
 
     value to_value() const
@@ -625,16 +691,23 @@ struct operation
         compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
     virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
     virtual argument
-    compute(const std::vector<argument>& input,
+    compute(const shape& output,
+            const std::vector<argument>& input,
+            const std::vector<module_ref>& module_args,
+            std::function<std::vector<argument>(
+                module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
+    virtual argument
+    compute(context& ctx,
+            const shape& output,
+            const std::vector<argument>& input,
             const std::vector<module_ref>& module_args,
             std::function<std::vector<argument>(
-                module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-        const = 0;
+                module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
     virtual value to_value() const          = 0;
     virtual void from_value(const value& v) = 0;
     virtual value attributes() const = 0;
     virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
     virtual bool operator==(const operation& y) const = 0;
 };
 
 template <class T>
@@ -828,25 +901,58 @@ struct operation
     template <class T>
     static auto private_detail_te_default_compute(
         char,
         T&& private_detail_te_self,
-        const std::vector<argument>& input,
-        const std::vector<module_ref>& module_args,
-        std::function<std::vector<argument>(
-            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-        -> decltype(private_detail_te_self.compute(input, module_args, std::move(run)))
-    {
-        return private_detail_te_self.compute(input, module_args, std::move(run));
-    }
-
-    template <class T>
-    static argument private_detail_te_default_compute(
-        float,
-        T&& private_detail_te_self,
-        const std::vector<argument>& input,
-        const std::vector<module_ref>& module_args,
-        std::function<std::vector<argument>(
-            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-    {
-        return detail::compute_op(private_detail_te_self, input, module_args, std::move(run));
-    }
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
+        -> decltype(private_detail_te_self.compute(output, input, module_args, std::move(run)))
+    {
+        return private_detail_te_self.compute(output, input, module_args, std::move(run));
+    }
+
+    template <class T>
+    static argument private_detail_te_default_compute(
+        float,
+        T&& private_detail_te_self,
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
+    {
+        return detail::compute_op(
+            private_detail_te_self, output, input, module_args, std::move(run));
+    }
+
+    template <class T>
+    static auto private_detail_te_default_compute(
+        char,
+        T&& private_detail_te_self,
+        context& ctx,
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
+        -> decltype(private_detail_te_self.compute(ctx, output, input, module_args, std::move(run)))
+    {
+        return private_detail_te_self.compute(ctx, output, input, module_args, std::move(run));
+    }
+
+    template <class T>
+    static argument private_detail_te_default_compute(
+        float,
+        T&& private_detail_te_self,
+        context& ctx,
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
+    {
+        return detail::compute_op(
+            private_detail_te_self, ctx, output, input, module_args, std::move(run));
    }
 
     template <class T>
@@ -994,16 +1100,29 @@ struct operation
             char(0), private_detail_te_value, output, input);
     }
 
-    argument
-    compute(const std::vector<argument>& input,
-            const std::vector<module_ref>& module_args,
-            std::function<std::vector<argument>(
-                module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-        const override
+    argument compute(
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(
+            module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
     {
         return private_detail_te_default_compute(
-            char(0), private_detail_te_value, input, module_args, std::move(run));
+            char(0), private_detail_te_value, output, input, module_args, std::move(run));
+    }
+
+    argument compute(
+        context& ctx,
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(
+            module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
+    {
+        return private_detail_te_default_compute(
+            char(0), private_detail_te_value, ctx, output, input, module_args, std::move(run));
     }
 
     value to_value() const override
...
@@ -35,6 +35,7 @@
 #include <migraphx/op/flatten.hpp>
 #include <migraphx/op/floor.hpp>
 #include <migraphx/op/gather.hpp>
+#include <migraphx/op/get_tuple_elem.hpp>
 #include <migraphx/op/greater.hpp>
 #include <migraphx/op/gru.hpp>
 #include <migraphx/op/identity.hpp>
@@ -71,6 +72,7 @@
 #include <migraphx/op/reduce_sum.hpp>
 #include <migraphx/op/relu.hpp>
 #include <migraphx/op/reshape.hpp>
+#include <migraphx/op/reverse.hpp>
 #include <migraphx/op/rnn.hpp>
 #include <migraphx/op/rnn_last_cell_output.hpp>
 #include <migraphx/op/rnn_last_hs_output.hpp>
...
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP

#include <migraphx/config.hpp>

#if defined(__has_include) && !defined(CPPCHECK)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#endif
#if __has_include(<experimental/optional>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#else
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif

#if MIGRAPHX_HAS_OPTIONAL
#include <optional>
#elif MIGRAPHX_HAS_OPTIONAL_TS
#include <experimental/optional>
#else
#error "No optional include available"
#endif

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

#if MIGRAPHX_HAS_OPTIONAL

template <class T>
using optional = std::optional<T>;

using nullopt_t        = std::nullopt_t;
constexpr auto nullopt = std::nullopt;

#elif MIGRAPHX_HAS_OPTIONAL_TS

template <class T>
using optional = std::experimental::optional<T>;

using nullopt_t        = std::experimental::nullopt_t;
constexpr auto nullopt = std::experimental::nullopt;

#endif

template <class T>
bool has_value(const optional<T>& x)
{
    return x != nullopt;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif // MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
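A tiny sketch of the portability shim in use (standalone, assuming only this header):

#include <migraphx/optional.hpp>

// find_even returns an engaged optional only when a suitable value exists.
migraphx::optional<int> find_even(int x)
{
    if(x % 2 == 0)
        return x;
    return migraphx::nullopt;
}

// Usage:
//   auto r = find_even(4);
//   if(migraphx::has_value(r)) { /* *r is 4 */ }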
#include <migraphx/inline_module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
    const auto& mod_inputs = ins->module_inputs();
    const auto* smod       = cond ? mod_inputs.at(0) : mod_inputs.at(1);
    std::unordered_map<instruction_ref, instruction_ref> map_ins;
    std::vector<instruction_ref> mod_outputs;
    for(auto sins : iterator_for(*smod))
    {
        instruction_ref copy_ins{};
        if(sins->name() == "@literal")
        {
            auto l   = sins->get_literal();
            copy_ins = m.add_literal(l);
        }
        else if(sins->name() == "@param")
        {
            auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
            auto s      = sins->get_shape();
            copy_ins    = m.add_parameter(name, s);
        }
        else if(sins->name() == "@outline")
        {
            auto s   = sins->get_shape();
            copy_ins = m.add_outline(s);
        }
        else
        {
            auto mod_args = sins->module_inputs();
            auto inputs   = sins->inputs();
            std::vector<instruction_ref> copy_inputs(inputs.size());
            std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
                return contains(map_ins, i) ? map_ins[i] : i;
            });
            if(sins->name() == "@return")
            {
                mod_outputs = copy_inputs;
                break;
            }
            copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
        }
        map_ins[sins] = copy_ins;
        mod_outputs   = {copy_ins};
    }

    auto ins_outputs = ins->outputs();
    assert(mod_outputs.size() >= ins_outputs.size());
    for(const auto& out : ins_outputs)
    {
        auto val = out->get_operator().to_value();
        assert(val.contains("index"));
        auto index = val.at("index").to<std::size_t>();
        m.replace_instruction(out, mod_outputs.at(index));
    }
}

void inline_module::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "if")
            continue;
        auto arg_cond = ins->inputs().front()->eval();
        if(not arg_cond.empty())
        {
            bool cond = arg_cond.at<bool>();
            inline_submodule(m, ins, cond);
        }
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/insert_pad.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m)
{
    auto op         = ins->get_operator();
    auto val        = op.to_value();
    auto op_padding = val.at("padding").to_vector<size_t>();
    auto kdims      = input->get_shape().lens().size() - 2;
    if(std::equal(op_padding.begin(),
                  op_padding.begin() + kdims,
                  op_padding.begin() + kdims,
                  op_padding.end()))
        return;

    std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
    std::vector<size_t> pads_l(op_padding.begin(), op_padding.begin() + kdims);
    std::vector<size_t> pads_r(op_padding.begin() + kdims, op_padding.end());
    op_padding = std::vector<size_t>(kdims * 2, 0);
    op.from_value({{"padding", op_padding}});

    std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
    std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);
    auto pad_op = m.insert_instruction(ins, op::pad{padding}, input);

    auto new_inputs    = ins->inputs();
    new_inputs.front() = pad_op;
    m.replace_instruction(ins, op, new_inputs);
}

static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
    auto op = any_cast<op::pooling>(ins->get_operator());
    if(op.mode == "average")
    {
        return;
    }
    auto kdims = input->get_shape().lens().size() - 2;
    if(std::equal(op.padding.begin(),
                  op.padding.begin() + kdims,
                  op.padding.begin() + kdims,
                  op.padding.end()))
        return;

    std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
    std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
    std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
    op.padding = std::vector<size_t>(kdims * 2, 0);

    std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
    std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);

    // maxpool uses lowest value for padding
    float pad_val = std::numeric_limits<float>::lowest();
    auto pad_op   = m.insert_instruction(ins, op::pad{padding, pad_val}, input);

    auto new_inputs    = ins->inputs();
    new_inputs.front() = pad_op;
    m.replace_instruction(ins, op, new_inputs);
}

void insert_pad::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        const std::string& op_name = ins->name();
        if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling")
            continue;
        auto input = ins->inputs().front();
        if(op_name == "convolution" or op_name == "im2col")
            update_op(input, ins, m);
        else if(op_name == "pooling")
            update_pooling(input, ins, m);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
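A minimal sketch of driving the pass directly (a real compile pipeline would schedule it through MIGraphX's pass manager instead):

#include <migraphx/insert_pad.hpp>
#include <migraphx/module.hpp>

// Rewrites every convolution/im2col/pooling with asymmetric padding so the
// padding becomes an explicit pad instruction feeding a zero-padded op.
void run_insert_pad(migraphx::module& m) { migraphx::insert_pad{}.apply(m); }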