"test/onnx/onnx_rnn_test.cpp" did not exist on "0d053d71ea09569d9fe87b30601e0faa88d10507"
Commit 00d5d880 authored by Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 00d90ca8 f60c3815
#ifndef MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct inline_module
{
std::string name() const { return "inline_module"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP
#include <string>
#include <vector>
#include <array>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* insert pads if attribute of padding is asymmetrical
*/
struct insert_pad
{
std::string name() const { return "insert_pad"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -164,6 +164,18 @@ struct hash<migraphx::instruction_ref>
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return &*x == &*y;
}
};
} // namespace std
#endif
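
The equal_to specialization above treats two instruction_refs as equal when they reference the same underlying instruction object. As a rough standalone illustration of that idiom (using std::list<int>::iterator and made-up iter_hash/iter_eq names, not MIGraphX types), the same address-based identity can key an unordered container:

#include <functional>
#include <iostream>
#include <list>
#include <unordered_set>

// Sketch: hash and compare std::list iterators by the address of the
// element they reference, mirroring the instruction_ref specializations.
struct iter_hash
{
    std::size_t operator()(std::list<int>::iterator it) const noexcept
    {
        return std::hash<int*>{}(&*it);
    }
};
struct iter_eq
{
    bool operator()(std::list<int>::iterator x, std::list<int>::iterator y) const noexcept
    {
        return &*x == &*y; // identity of the referenced element
    }
};

int main()
{
    std::list<int> l{1, 2, 3};
    std::unordered_set<std::list<int>::iterator, iter_hash, iter_eq> seen;
    seen.insert(l.begin());
    std::cout << seen.count(l.begin()) << "\n"; // 1: same element, equal keys
}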
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator, class EndIterator>
auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(!it._M_dereferenceable())
{
return !it._M_dereferenceable();
}
template <class Iterator, class EndIterator>
auto is_end(rank<1>, Iterator it, EndIterator last)
{
return it == last;
}
template <class Iterator, class EndIterator>
bool is_end(Iterator it, EndIterator last)
{
return is_end(rank<2>{}, it, last);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
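
is_end prefers the libstdc++-specific _M_dereferenceable() probe when that expression compiles and otherwise falls back to comparing against the end iterator; the preference is encoded by the rank<2>/rank<1> tag arguments. Below is a minimal self-contained sketch of that rank-dispatch idiom, with a local rank template standing in for migraphx/rank.hpp and a made-up describe() function:

#include <iostream>
#include <string>
#include <vector>

// Stand-in for migraphx::rank: rank<N> derives from rank<N-1>, so an
// overload taking rank<2> is preferred over one taking rank<1>.
template <int N>
struct rank : rank<N - 1> {};
template <>
struct rank<0> {};

// Preferred overload: only participates if x.size() is a valid expression.
template <class T>
auto describe(rank<2>, const T& x) -> decltype(x.size(), std::string())
{
    return "sized container with " + std::to_string(x.size()) + " elements";
}

// Fallback overload.
template <class T>
std::string describe(rank<1>, const T&)
{
    return "unsized value";
}

template <class T>
std::string describe(const T& x)
{
    return describe(rank<2>{}, x); // start at the highest rank
}

int main()
{
    std::vector<int> v{1, 2, 3};
    std::cout << describe(v) << "\n";  // sized container with 3 elements
    std::cout << describe(42) << "\n"; // unsized value
}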
@@ -5,7 +5,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
@@ -19,24 +19,51 @@ namespace match {
struct matcher_context
{
matcher_context(instruction_ref i) : last(i) {}
matcher_context(module& m) : mod(&m) {}
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template <class M>
bool matched(M m, instruction_ref ins)
{
return m.match(*this, ins) != this->not_found();
return has_value(m.match(*this, ins));
}
template <class M>
auto lazy_match(M m, instruction_ref ins)
bool matched(M m, optional<instruction_ref> ins)
{
if(ins)
return has_value(m.match(*this, *ins));
return false;
}
template <class M, class I>
auto lazy_match(M m, I ins)
{
return [=] { return this->matched(m, ins); };
}
bool has_instruction(instruction_ref ins) const
{
if(mod == nullptr)
return true;
return mod->has_instruction(ins);
}
bool has_instruction(optional<instruction_ref> ins) const
{
if(ins)
return this->has_instruction(*ins);
return false;
}
bool is_last(instruction_ref ins) const
{
assert(mod->begin() != mod->end());
assert(this->has_instruction(ins));
return ins == std::prev(mod->end());
}
private:
instruction_ref last;
module* mod = nullptr;
};
/// Convert a predicate function into a matcher
@@ -45,12 +72,11 @@ struct predicate_matcher
{
P p;
instruction_ref match(const matcher_context& ctx, instruction_ref ins) const
optional<instruction_ref> match(const matcher_context&, instruction_ref ins) const
{
assert(ins != ctx.not_found());
if(p(ins))
return ins;
return ctx.not_found();
return optional<instruction_ref>{ins};
return nullopt;
}
};
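
This hunk is the heart of the refactor: matchers now return optional<instruction_ref> and signal failure with nullopt instead of the not_found() sentinel iterator. A small standalone sketch of the same pattern with std::optional (node and pred_matcher are invented names for illustration, not MIGraphX types):

#include <iostream>
#include <optional>
#include <string>

struct node
{
    std::string name;
};

// A predicate matcher in the optional style: return the node on success,
// std::nullopt on failure, instead of a sentinel "not found" value.
template <class P>
struct pred_matcher
{
    P p;
    std::optional<const node*> match(const node* n) const
    {
        if(p(*n))
            return n;
        return std::nullopt;
    }
};

template <class P>
pred_matcher<P> make_pred(P p) { return {std::move(p)}; }

int main()
{
    node add{"add"};
    auto is_add = make_pred([](const node& n) { return n.name == "add"; });
    if(auto r = is_add.match(&add))
        std::cout << "matched " << (*r)->name << "\n";
    else
        std::cout << "no match\n";
}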
@@ -60,11 +86,7 @@ struct function_matcher
{
F f;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
assert(ins != ctx.not_found());
return f(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return f(ctx, ins); }
};
/// Convert a function into a matcher
@@ -79,12 +101,17 @@ template <class M>
auto bind_match(M m, std::string name)
{
return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
auto result = m.match(ctx, ins);
if(result != ctx.not_found())
ctx.instructions[name] = ins;
return result;
});
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins)
->optional<instruction_ref> {
auto result = m.match(ctx, ins);
if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
}
/// Convert a matcher to a bindable matcher
@@ -95,10 +122,7 @@ struct bindable_matcher
auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
};
/// Create a bindable matcher
@@ -126,7 +150,10 @@ using bool_list = std::initializer_list<bool>;
struct id_matcher
{
instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
auto match(matcher_context&, instruction_ref ins) const
{
return optional<instruction_ref>{ins};
}
};
/// The basic matcher provides the all_of composability of the matcher
@@ -140,26 +167,23 @@ struct basic_matcher
{
// Copy m because we can't capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result != ctx.not_found())
if(result)
{
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, result) != ctx.not_found();
})(true, ms...);
bool matches =
fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
if(matches)
return result;
}
return ctx.not_found();
return nullopt;
});
}
auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
};
/// Create a basic matcher from a matcher
@@ -185,7 +209,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
/// Create a type-erased matcher
using any_matcher_base = basic_matcher<
function_matcher<std::function<instruction_ref(matcher_context&, instruction_ref)>>>;
function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base
{
template <class M>
@@ -198,10 +222,10 @@ struct any_matcher : any_matcher_base
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
optional<instruction_ref> match(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const
inline optional<instruction_ref> name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPHX_PRED_MATCHER(name, ...) \
@@ -221,21 +245,29 @@ struct matcher_result
/// Match a single instruction
template <class M>
matcher_result match_instruction(module& p, instruction_ref ins, M&& m)
matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
assert(ins != p.end());
assert(ins != mod.end());
assert(mod.has_instruction(ins));
matcher_context ctx{mod};
matcher_result result;
matcher_context ctx{p.end()};
result.result = m.match(ctx, ins);
result.instructions = ctx.instructions;
if(m.match(ctx, ins))
{
result.result = ins;
result.instructions = ctx.instructions;
}
else
{
result.result = mod.end();
}
return result;
}
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the program
/// Find matches for an instruction in the module
template <class... Ms>
void find_matches(module& p, instruction_ref ins, Ms&&... ms)
void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
@@ -246,27 +278,27 @@ void find_matches(module& p, instruction_ref ins, Ms&&... ms)
[&](auto&& m) {
if(match)
return;
auto r = match_instruction(p, ins, m.matcher());
if(r.result == p.end())
auto r = match_instruction(mod, ins, m.matcher());
if(r.result == mod.end())
return;
if(trace)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
p.debug_print(ins);
mod.debug_print(ins);
}
m.apply(p, r);
m.apply(mod, r);
match = true;
},
ms...);
}
/// Find matches in a program
/// Find matches in a module
template <class... Ms>
void find_matches(module& p, Ms&&... ms)
void find_matches(module& mod, Ms&&... ms)
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(mod))
{
find_matches(p, ins, ms...);
find_matches(mod, ins, ms...);
}
}
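
find_matches visits every instruction and tries the matchers in order, and the match flag makes the first successful matcher win for that instruction. A simplified, self-contained sketch of that first-match-wins loop (rewrite rules over plain strings, purely illustrative):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct rewrite
{
    // Returns true if the rule matched and was applied.
    std::function<bool(std::string&)> apply;
};

// First-match-wins: once a rule fires for an instruction, skip the rest,
// mirroring the `match` flag in find_matches.
void apply_first_match(std::vector<std::string>& instructions, const std::vector<rewrite>& rules)
{
    for(auto& ins : instructions)
    {
        for(const auto& rule : rules)
        {
            if(rule.apply(ins))
                break;
        }
    }
}

int main()
{
    std::vector<std::string> prog{"add", "mul", "add"};
    std::vector<rewrite> rules{
        {[](std::string& s) { if(s != "add") return false; s = "fused_add"; return true; }},
        {[](std::string& s) { if(s != "mul") return false; s = "fused_mul"; return true; }}};
    apply_first_match(prog, rules);
    for(const auto& s : prog)
        std::cout << s << "\n";
}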
@@ -339,12 +371,13 @@ struct match_fold_f
template <class... Ts>
auto operator()(Ts... ms) const
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches)
return ins;
return ctx.not_found();
});
return make_bf_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches)
return {ins};
return nullopt;
});
}
template <class Selector>
@@ -353,17 +386,18 @@ struct match_fold_f
return [=](auto... ms) {
// Workaround ICE on gcc by packing matchers into an object
auto mpack = pack(ms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm)();
return make_bf_matcher(
[=](matcher_context& ctx, instruction_ref start) -> optional<instruction_ref> {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm)();
});
if(matches == Matches)
return {start};
return nullopt;
});
if(matches == Matches)
return start;
return ctx.not_found();
});
};
}
};
@@ -420,64 +454,29 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins)
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins->outputs().front();
return ctx.not_found();
return {ins->outputs().front()};
return nullopt;
}
MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins;
if(ins->outputs().empty() and std::next(ins) == ctx.not_found())
return ins;
return ctx.not_found();
}
inline auto used_once_recursive(std::size_t depth)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
// Used once
if(start->outputs().size() == 1)
return start;
// Unused
if(start->outputs().empty())
{
if(std::next(start) == ctx.not_found())
return start;
else
return ctx.not_found();
}
// Check for dead instructions
auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
if(n == 0)
return false;
if(ins->get_shape().elements() == 0)
return false;
if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
return true;
return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return self(i, n - 1);
});
});
auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
return is_dead(i, depth);
});
if(dead + 1 == start->outputs().size())
return start;
return ctx.not_found();
});
return {ins};
if(ins->outputs().empty() and ctx.is_last(ins))
return {ins};
return nullopt;
}
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
return ins;
return ctx.not_found();
if(ins->outputs().empty() and not ctx.is_last(ins))
return {ins};
return nullopt;
}
template <class... Ms>
@@ -485,14 +484,15 @@ auto skip(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) {
if(ins->inputs().size() == 1 and ctx.matched(m, ins))
{
auto next = ins->inputs().front();
return self(next);
}
return ins;
})(start);
return fix<optional<instruction_ref>>(
[&](auto self, auto ins) -> optional<instruction_ref> {
if(ins->inputs().size() == 1 and ctx.matched(m, ins))
{
auto next = ins->inputs().front();
return self(next);
}
return ins;
})(start);
});
}
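
skip now recurses with fix<optional<instruction_ref>>, a fixed-point helper that lets a lambda call itself while threading the optional result type through. A minimal standalone version of such a combinator is sketched below; this local fix is only an illustration and not the definition in migraphx/functional.hpp:

#include <iostream>
#include <utility>

// Minimal fixed-point combinator: passes the wrapped lambda to itself as the
// first argument so it can recurse without having a name.
template <class R, class F>
auto fix(F f)
{
    return [=](auto&&... xs) -> R {
        return f(fix<R>(f), std::forward<decltype(xs)>(xs)...);
    };
}

int main()
{
    // Recursive lambda computing a factorial through the combinator.
    auto factorial = fix<long>([](auto self, long n) -> long {
        if(n <= 1)
            return 1;
        return n * self(n - 1);
    });
    std::cout << factorial(5) << "\n"; // 120
}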
@@ -501,20 +501,21 @@ auto skip_output(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) {
if(ins->outputs().size() == 1)
{
auto next = ins->outputs().front();
if(ctx.matched(m, next))
return fix<optional<instruction_ref>>(
[&](auto self, auto ins) -> optional<instruction_ref> {
if(ins->outputs().size() == 1)
{
auto skipped_next = self(next);
if(skipped_next != ctx.not_found())
return skipped_next;
auto next = ins->outputs().front();
if(ctx.matched(m, next))
{
auto skipped_next = self(next);
if(skipped_next)
return skipped_next;
}
return next;
}
return next;
}
return ctx.not_found();
})(start);
return nullopt;
})(start);
});
}
@@ -550,11 +551,12 @@ inline auto nargs(std::size_t n)
inline auto arg(std::size_t i)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) {
if(i < ins->inputs().size())
return ins->inputs()[i];
return ctx.not_found();
});
return make_basic_fun_matcher(
[=](const matcher_context&, instruction_ref ins) -> optional<instruction_ref> {
if(i < ins->inputs().size())
return ins->inputs()[i];
return nullopt;
});
}
// Workaround for bugs in clang
@@ -616,52 +618,55 @@ std::size_t tree_leafs_impl(matcher_context& ctx,
template <class M, class... Ms>
auto tree(M main_op, Ms... ms)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return ctx.not_found();
// Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return nullopt;
// Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
});
if(not found)
return nullopt;
return ins;
});
if(not found)
return ctx.not_found();
return ins;
});
}
template <class M, class... Ms>
auto unordered_tree(M main_op, Ms... ms)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return ctx.not_found();
// Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
})(ms...)();
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return nullopt;
// Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
})(ms...)();
});
if(not found)
return nullopt;
return ins;
});
if(not found)
return ctx.not_found();
return ins;
});
}
template <class M>
auto same_shape(M m)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
auto i = m.match(ctx, ins);
if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
return ins;
return ctx.not_found();
});
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
auto i = m.match(ctx, ins);
if(i and (*i)->get_shape() == ins->get_shape())
return ins;
return nullopt;
});
}
template <class... Ms>
......
@@ -9,6 +9,8 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
@@ -39,25 +41,30 @@ struct convolution
void check_attribute_size() const
{
if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{
MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
}
}
shape compute_shape(std::vector<shape> inputs) const
value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size();
// dim num of input and attribute should match
if(inputs[0].lens().size() != padding.size() + 2)
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
}
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
size_t kdims = input.lens().size() - 2;
size_t kdims = input_size - 2;
if(kdims != this->kdims())
{
MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
@@ -70,10 +77,13 @@ struct convolution
for(size_t i = 0; i < kdims; i++)
{
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
2 * padding[i]) /
padding_factor) /
stride[i] +
1)));
}
@@ -84,7 +94,7 @@ struct convolution
size_t kdims() const
{
check_attribute_size();
return padding.size();
return stride.size();
}
};
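
With normalize_padding, the padding vector may carry 2*kdims entries (per-dimension begin and end pads), and the output-length formula swaps 2*padding[i] for padding[i] + padding[i + kdims]. A small standalone check of that formula, using a hypothetical conv_out_len helper:

#include <algorithm>
#include <cstddef>
#include <iostream>

// Hypothetical helper: output length of one spatial dimension for convolution,
// allowing an asymmetric begin/end pad pair as normalize_padding permits.
std::size_t conv_out_len(std::size_t in,
                         std::size_t kernel,
                         std::size_t stride,
                         std::size_t dilation,
                         std::size_t pad_begin,
                         std::size_t pad_end)
{
    std::ptrdiff_t padding_factor = std::ptrdiff_t(pad_begin + pad_end);
    std::ptrdiff_t numer =
        std::ptrdiff_t(in) - std::ptrdiff_t(1 + dilation * (kernel - 1)) + padding_factor;
    return std::size_t(std::max<std::ptrdiff_t>(1, numer / std::ptrdiff_t(stride) + 1));
}

int main()
{
    // Asymmetric pads {1, 2}: (7 - 3 + 1 + 2) / 1 + 1 = 8
    std::cout << conv_out_len(7, 3, 1, 1, 1, 2) << "\n";
    // Symmetric pads {1, 1}:  (7 - 3 + 1 + 1) / 1 + 1 = 7
    std::cout << conv_out_len(7, 3, 1, 1, 1, 1) << "\n";
}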
......
@@ -39,7 +39,8 @@ struct deconvolution
void check_attribute_size() const
{
if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{
MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes");
}
@@ -72,7 +73,7 @@ struct deconvolution
size_t kdims() const
{
check_attribute_size();
return padding.size();
return stride.size();
}
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#define MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#include "migraphx/errors.hpp"
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct get_tuple_elem
{
std::size_t index = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.index, "index"));
}
std::string name() const { return "get_tuple_elem"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).tuple_type();
const auto& sub_shapes = inputs.at(0).sub_shapes();
if(index >= sub_shapes.size())
{
MIGRAPHX_THROW("GET_TUPLE_ELEM: index " + std::to_string(index) + " is out of range " +
std::to_string(sub_shapes.size()));
}
return sub_shapes.at(index);
}
argument compute(const shape&, std::vector<argument> args) const
{
assert(args.size() == 1);
auto vec_args = args.at(0).get_sub_objects();
assert(index < vec_args.size());
return vec_args.at(index);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -35,14 +35,14 @@ struct if_op
MIGRAPHX_THROW("IF: output shapes of submodules must be the same.");
}
return out_shapes0.front();
return shape(out_shapes0);
}
argument compute(
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)>& run) const
argument compute(const shape&,
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
auto cond = args.front().at<bool>();
module_ref mod = cond ? mods[0] : mods[1];
@@ -63,7 +63,7 @@ struct if_op
[](auto&& name, auto&& arg) { return std::make_pair(name, arg); });
auto results = run(mod, params);
return results[0];
return argument{results};
}
};
......
@@ -31,7 +31,9 @@ struct im2col
std::string name() const { return "im2col"; }
shape compute_shape(std::vector<shape> inputs) const
value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
auto input = inputs[0];
auto weights = inputs[1];
@@ -42,17 +44,24 @@ struct im2col
check_shapes{inputs, *this}.has(2);
if(batch_size != 1)
MIGRAPHX_THROW("im2col only support batch_size 1");
auto padding_h = 2 * padding[0];
auto padding_w = 2 * padding[1];
if(padding.size() == 2 * stride.size())
{
padding_h = padding[0] + padding[2];
padding_w = padding[1] + padding[3];
}
auto output_height = std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
stride[0] +
(input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + padding_h) / stride[0] +
1));
auto output_width = std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) /
stride[1] +
(input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + padding_w) / stride[1] +
1));
auto channels_col = kernel_height * kernel_width * input_channels;
return {input.type(), {output_height * output_width, channels_col}};
}
};
......
@@ -15,6 +15,7 @@ namespace op {
// 3.1) include_min(default)/exclude_min
// 4) clip_max(default)/not_clip_max
// 4.1) exclude_max(default)/include_max
// 5) normalize padding
enum class normalize_attribute
{
use_len,
@@ -22,7 +23,8 @@ enum class normalize_attribute
clip_max,
clip_min,
include_max,
include_min
include_min,
normalize_padding
};
} // namespace op
......
@@ -8,6 +8,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/value.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/int_divide.hpp>
#include <migraphx/config.hpp>
@@ -40,29 +41,39 @@ struct pooling
void check_attribute_size() const
{
if(not(padding.size() == stride.size() and padding.size() == lengths.size()))
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == lengths.size()))
{
MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
}
}
shape compute_shape(std::vector<shape> inputs) const
value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
const shape& input = inputs.at(0);
auto input_lens = input.lens();
size_t kdims = input_lens.size() - 2;
if(kdims != this->kdims())
auto input_lens = input.lens();
size_t kdims = input_lens.size() - 2;
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{
MIGRAPHX_THROW("pooling: input k-dims does not match attribute size");
MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
}
std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
for(size_t i = 0; i < kdims; i++)
{
std::ptrdiff_t dim_size = input_lens[i + 2] + 2 * padding[i] - lengths[i];
std::ptrdiff_t dim_size;
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
dim_size = input_lens[i + 2] + padding_factor - lengths[i];
assert(dim_size >= 0);
std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i])
: floor_divide<std::ptrdiff_t>(dim_size, stride[i]);
@@ -75,7 +86,7 @@ struct pooling
size_t kdims() const
{
check_attribute_size();
return padding.size();
return stride.size();
}
};
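
Pooling applies the same begin/end padding split and then rounds the stride division up or down depending on ceil_mode. A standalone sketch of that length computation (pool_out_len is a made-up helper; it assumes a non-negative dim_size, as the operator asserts):

#include <cstddef>
#include <iostream>

// One spatial output length for pooling: pad, subtract the window length,
// then divide by the stride with floor or ceil rounding.
std::size_t pool_out_len(std::ptrdiff_t in,
                         std::ptrdiff_t window,
                         std::ptrdiff_t stride,
                         std::ptrdiff_t pad_begin,
                         std::ptrdiff_t pad_end,
                         bool ceil_mode)
{
    std::ptrdiff_t dim_size = in + pad_begin + pad_end - window;
    std::ptrdiff_t len =
        ceil_mode ? (dim_size + stride - 1) / stride // ceil divide (non-negative inputs)
                  : dim_size / stride;               // floor divide
    return std::size_t(len + 1);
}

int main()
{
    // 5-wide input, 2-wide window, stride 2, no padding:
    std::cout << pool_out_len(5, 2, 2, 0, 0, false) << "\n"; // 2
    std::cout << pool_out_len(5, 2, 2, 0, 0, true) << "\n";  // 3
}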
......
@@ -36,19 +36,23 @@ struct quant_convolution
f(self.group, "group"));
}
value attributes() const { return {{"general_data_type", "convolution"}}; }
value attributes() const
{
return {{"general_data_type", "convolution"}, {"normalize_padding", "padding"}};
}
std::string name() const { return "quant_convolution"; }
void check_attribute_size() const
{
if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{
MIGRAPHX_THROW("quant_convolution: inconsistent attribute sizes");
MIGRAPHX_THROW("QUANT_CONVOLUTION: inconsistent attribute sizes");
}
}
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size();
@@ -70,13 +74,16 @@ struct quant_convolution
t = shape::int32_type;
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
auto padding_size = padding.size();
for(size_t i = 0; i < kdims; i++)
{
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
2 * padding[i]) /
padding_factor) /
stride[i] +
1)));
}
@@ -87,7 +94,7 @@ struct quant_convolution
size_t kdims() const
{
check_attribute_size();
return padding.size();
return stride.size();
}
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_REVERSE_HPP
#define MIGRAPHX_GUARD_OPERATORS_REVERSE_HPP
#include <algorithm>
#include <vector>
#include <cmath>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reverse
{
std::vector<int64_t> axes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "reverse"; }
value attributes() const
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
shape normalize_compute_shape(std::vector<shape> inputs) const
{
return inputs[0].with_lens(inputs[0].lens());
}
argument compute(const shape& s, std::vector<argument> args) const
{
argument result{s};
auto lens = s.lens();
visit_all(result, args.front())([&](auto output, auto input) {
shape_for_each(s, [&](const auto& out_idx) {
auto in_idx = out_idx;
for(const auto& axis : axes)
{
in_idx[axis] = lens[axis] - 1 - out_idx[axis];
}
output[s.index(out_idx)] = input[s.index(in_idx)];
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
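
reverse computes each output element by flipping the index along the listed axes: in_idx[axis] = lens[axis] - 1 - out_idx[axis]. The same indexing on a row-major 2D array, as a standalone sketch (reverse_axis1 is an illustrative helper, not the MIGraphX kernel):

#include <cstddef>
#include <iostream>
#include <vector>

// Reverse a row-major matrix along axis 1 (columns) by flipping the index.
std::vector<int> reverse_axis1(const std::vector<int>& in, std::size_t rows, std::size_t cols)
{
    std::vector<int> out(in.size());
    for(std::size_t r = 0; r < rows; r++)
    {
        for(std::size_t c = 0; c < cols; c++)
        {
            std::size_t in_c = cols - 1 - c; // in_idx[axis] = lens[axis] - 1 - out_idx[axis]
            out[r * cols + c] = in[r * cols + in_c];
        }
    }
    return out;
}

int main()
{
    std::vector<int> m{1, 2, 3, 4, 5, 6}; // 2x3
    for(int v : reverse_axis1(m, 2, 3))
        std::cout << v << ' '; // 3 2 1 6 5 4
    std::cout << "\n";
}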
@@ -7,6 +7,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
@@ -25,8 +26,15 @@ struct step
return pack(f(self.axes, "axes"), f(self.steps, "steps"));
}
value attributes() const
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "step"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0);
......
@@ -178,7 +178,7 @@ shape normalize_compute_shape_op(const T& x,
}
template <class T>
auto compute_op(rank<2>,
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output_shape,
@@ -188,14 +188,6 @@ auto compute_op(rank<2>,
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
@@ -207,50 +199,106 @@ template <class T>
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, ctx, output_shape, input);
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<1>{}, x, output_shape, input);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
template <class T, class F>
argument compute_op(const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<2>{}, x, output_shape, input);
return compute_op(rank<1>{}, x, output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<1>,
auto compute_op(rank<3>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(inputs, module_args, f))
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(inputs, module_args, f);
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
argument
compute_op(rank<0>, const T& x, const std::vector<argument>&, const std::vector<module_ref>&, F)
auto compute_op(rank<2>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs))
{
return x.compute(output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
context&,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
@@ -258,11 +306,13 @@ argument
template <class T, class F>
argument compute_op(const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<1>{}, x, inputs, module_args, f);
return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
}
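
The dispatcher starts at rank<3> and falls through to lower-ranked overloads whose trailing decltype compiles, so an operator is called with whichever compute signature it actually provides. A reduced two-level sketch of the same idiom (with_ctx and without_ctx are invented operator types for illustration):

#include <iostream>
#include <string>

template <int N>
struct rank : rank<N - 1> {};
template <>
struct rank<0> {};

struct with_ctx
{
    std::string compute(int ctx, int x) const { return "ctx compute " + std::to_string(ctx + x); }
};
struct without_ctx
{
    std::string compute(int x) const { return "plain compute " + std::to_string(x); }
};

// Prefer the context-taking signature when the operator provides it...
template <class T>
auto compute_op(rank<2>, const T& op, int ctx, int x) -> decltype(op.compute(ctx, x))
{
    return op.compute(ctx, x);
}
// ...otherwise fall back to the context-free signature.
template <class T>
auto compute_op(rank<1>, const T& op, int, int x) -> decltype(op.compute(x))
{
    return op.compute(x);
}

template <class T>
std::string compute_op(const T& op, int ctx, int x)
{
    return compute_op(rank<2>{}, op, ctx, x);
}

int main()
{
    std::cout << compute_op(with_ctx{}, 10, 5) << "\n";    // ctx compute 15
    std::cout << compute_op(without_ctx{}, 10, 5) << "\n"; // plain compute 5
}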
template <class T>
@@ -409,9 +459,12 @@ bool is_borrowed_op(const T&)
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
* mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
* input) const; argument compute(const shape& output,const std::vector<argument>& input)
* const; argument compute(const std::vector<argument>& input,const std::vector<module_ref>&
* module_args,std::function<std::vector<argument>(module_ref& mdl, const
* std::unordered_map<std::string, argument>& inputs)> run) const; value to_value() const; void
* const; argument compute(const shape& output,const std::vector<argument>& input,const
* std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const
* std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const
* shape& output,const std::vector<argument>& input,const std::vector<module_ref>&
* module_args,std::function<std::vector<argument>(module_ref&, const
* std::unordered_map<std::string, argument>&)> run) const; value to_value() const; void
* from_value(const value& v) ; value attributes() const; friend std::ostream &
* operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation &
* x,const operation & y) ;
@@ -555,14 +608,27 @@ struct operation
return (*this).private_detail_te_get_handle().compute(output, input);
}
argument compute(
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run) const
argument compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(input, module_args, std::move(run));
return (*this).private_detail_te_get_handle().compute(
output, input, module_args, std::move(run));
}
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(
ctx, output, input, module_args, std::move(run));
}
value to_value() const
@@ -625,16 +691,23 @@ struct operation
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual argument
compute(const std::vector<argument>& input,
compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
virtual argument
compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual value attributes() const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual value attributes() const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
};
template <class T>
@@ -828,25 +901,58 @@ struct operation
static auto private_detail_te_default_compute(
char,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
-> decltype(private_detail_te_self.compute(output, input, module_args, std::move(run)))
{
return private_detail_te_self.compute(output, input, module_args, std::move(run));
}
template <class T>
static argument private_detail_te_default_compute(
float,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
{
return detail::compute_op(
private_detail_te_self, output, input, module_args, std::move(run));
}
template <class T>
static auto private_detail_te_default_compute(
char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-> decltype(private_detail_te_self.compute(input, module_args, std::move(run)))
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
-> decltype(private_detail_te_self.compute(ctx, output, input, module_args, std::move(run)))
{
return private_detail_te_self.compute(input, module_args, std::move(run));
return private_detail_te_self.compute(ctx, output, input, module_args, std::move(run));
}
template <class T>
static argument private_detail_te_default_compute(
float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
{
return detail::compute_op(private_detail_te_self, input, module_args, std::move(run));
return detail::compute_op(
private_detail_te_self, ctx, output, input, module_args, std::move(run));
}
template <class T>
@@ -994,16 +1100,29 @@ struct operation
char(0), private_detail_te_value, output, input);
}
argument
compute(const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
const override
argument compute(
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
{
return private_detail_te_default_compute(
char(0), private_detail_te_value, output, input, module_args, std::move(run));
}
argument compute(
context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
{
return private_detail_te_default_compute(
char(0), private_detail_te_value, input, module_args, std::move(run));
char(0), private_detail_te_value, ctx, output, input, module_args, std::move(run));
}
value to_value() const override
......
@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/identity.hpp>
@@ -71,6 +72,7 @@
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reverse.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_hs_output.hpp>
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#endif
#if __has_include(<experimental/optional>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#else
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#if MIGRAPHX_HAS_OPTIONAL
#include <optional>
#elif MIGRAPHX_HAS_OPTIONAL_TS
#include <experimental/optional>
#else
#error "No optional include available"
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#if MIGRAPHX_HAS_OPTIONAL
template <class T>
using optional = std::optional<T>;
using nullopt_t = std::nullopt_t;
constexpr auto nullopt = std::nullopt;
#elif MIGRAPHX_HAS_OPTIONAL_TS
template <class T>
using optional = std::experimental::optional<T>;
using nullopt_t = std::experimental::nullopt_t;
constexpr auto nullopt = std::experimental::nullopt;
#endif
template <class T>
bool has_value(const optional<T>& x)
{
return x != nullopt;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
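
The has_value free function gives the matcher code one spelling for "holds a result" whether optional maps to std::optional or std::experimental::optional. A tiny usage sketch against std::optional (assuming C++17; find_even is just an example producer):

#include <iostream>
#include <optional>

// Free-function test mirroring migraphx::has_value: true when the optional
// holds a value, regardless of which optional implementation is in use.
template <class T>
bool has_value(const std::optional<T>& x)
{
    return x != std::nullopt;
}

std::optional<int> find_even(int a, int b)
{
    if(a % 2 == 0)
        return a;
    if(b % 2 == 0)
        return b;
    return std::nullopt;
}

int main()
{
    std::cout << has_value(find_even(3, 4)) << "\n"; // 1
    std::cout << has_value(find_even(3, 5)) << "\n"; // 0
}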
#include <migraphx/inline_module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
const auto& mod_inputs = ins->module_inputs();
const auto* smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*smod))
{
instruction_ref copy_ins{};
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = m.add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = m.add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = m.add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
mod_outputs = {copy_ins};
}
auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size());
for(const auto& out : ins_outputs)
{
auto val = out->get_operator().to_value();
assert(val.contains("index"));
auto index = val.at("index").to<std::size_t>();
m.replace_instruction(out, mod_outputs.at(index));
}
}
void inline_module::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "if")
continue;
auto arg_cond = ins->inputs().front()->eval();
if(not arg_cond.empty())
{
bool cond = arg_cond.at<bool>();
inline_submodule(m, ins, cond);
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/insert_pad.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m)
{
auto op = ins->get_operator();
auto val = op.to_value();
auto op_padding = val.at("padding").to_vector<size_t>();
auto kdims = input->get_shape().lens().size() - 2;
if(std::equal(op_padding.begin(),
op_padding.begin() + kdims,
op_padding.begin() + kdims,
op_padding.end()))
return;
std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
std::vector<size_t> pads_l(op_padding.begin(), op_padding.begin() + kdims);
std::vector<size_t> pads_r(op_padding.begin() + kdims, op_padding.end());
op_padding = std::vector<size_t>(kdims * 2, 0);
op.from_value({{"padding", op_padding}});
std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);
auto pad_op = m.insert_instruction(ins, op::pad{padding}, input);
auto new_inputs = ins->inputs();
new_inputs.front() = pad_op;
m.replace_instruction(ins, op, new_inputs);
}
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average")
{
return;
}
auto kdims = input->get_shape().lens().size() - 2;
if(std::equal(op.padding.begin(),
op.padding.begin() + kdims,
op.padding.begin() + kdims,
op.padding.end()))
return;
std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
op.padding = std::vector<size_t>(kdims * 2, 0);
std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);
// maxpool uses lowest value for padding
float pad_val = std::numeric_limits<float>::lowest();
auto pad_op = m.insert_instruction(ins, op::pad{padding, pad_val}, input);
auto new_inputs = ins->inputs();
new_inputs.front() = pad_op;
m.replace_instruction(ins, op, new_inputs);
}
void insert_pad::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
const std::string& op_name = ins->name();
if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling")
continue;
auto input = ins->inputs().front();
if(op_name == "convolution" or op_name == "im2col")
update_op(input, ins, m);
else if(op_name == "pooling")
update_pooling(input, ins, m);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
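
The pass only rewrites when the begin half and end half of the operator padding differ, which the four-iterator std::equal checks, and it then scatters those spatial pads into a full {begin pads..., end pads...} vector for op::pad while leaving the batch and channel entries at zero. A standalone sketch of both steps (symmetric and make_pad_vector are hypothetical helpers mirroring the logic above):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Is the operator padding symmetric? op_padding holds kdims begin values
// followed by kdims end values; the four-iterator std::equal compares the
// two halves element by element.
bool symmetric(const std::vector<std::size_t>& op_padding, std::size_t kdims)
{
    return std::equal(op_padding.begin(),
                      op_padding.begin() + kdims,
                      op_padding.begin() + kdims,
                      op_padding.end());
}

// Scatter asymmetric spatial pads into a full pad vector covering all ndims
// dimensions ({begin pads..., end pads...}), leaving batch and channel at 0.
std::vector<int64_t> make_pad_vector(const std::vector<std::size_t>& op_padding,
                                     std::size_t ndims) // ndims = kdims + 2
{
    std::size_t kdims = ndims - 2;
    std::vector<int64_t> padding(ndims * 2, 0);
    std::copy(op_padding.begin(), op_padding.begin() + kdims, padding.begin() + 2);
    std::copy(op_padding.begin() + kdims, op_padding.end(), padding.begin() + ndims + 2);
    return padding;
}

int main()
{
    std::vector<std::size_t> pads{1, 1, 2, 2}; // begin {1,1}, end {2,2}
    std::cout << symmetric(pads, 2) << "\n";   // 0: asymmetric, needs a pad op
    for(auto p : make_pad_vector(pads, 4))
        std::cout << p << ' '; // 0 0 1 1 0 0 2 2
    std::cout << "\n";
}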