Commit 4a39a0f7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
enum class lifetime
{
local,
global,
borrow
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/program.hpp> #include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -19,24 +19,51 @@ namespace match { ...@@ -19,24 +19,51 @@ namespace match {
struct matcher_context struct matcher_context
{ {
matcher_context(instruction_ref i) : last(i) {} matcher_context(module& m) : mod(&m) {}
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template <class M> template <class M>
bool matched(M m, instruction_ref ins) bool matched(M m, instruction_ref ins)
{ {
return m.match(*this, ins) != this->not_found(); return has_value(m.match(*this, ins));
} }
template <class M> template <class M>
auto lazy_match(M m, instruction_ref ins) bool matched(M m, optional<instruction_ref> ins)
{
if(ins)
return has_value(m.match(*this, *ins));
return false;
}
template <class M, class I>
auto lazy_match(M m, I ins)
{ {
return [=] { return this->matched(m, ins); }; return [=] { return this->matched(m, ins); };
} }
bool has_instruction(instruction_ref ins) const
{
if(mod == nullptr)
return true;
return mod->has_instruction(ins);
}
bool has_instruction(optional<instruction_ref> ins) const
{
if(ins)
return this->has_instruction(*ins);
return false;
}
bool is_last(instruction_ref ins) const
{
assert(mod->begin() != mod->end());
assert(this->has_instruction(ins));
return ins == std::prev(mod->end());
}
private: private:
instruction_ref last; module* mod = nullptr;
}; };
/// Convert a predicate function into a matcher /// Convert a predicate function into a matcher
...@@ -45,12 +72,11 @@ struct predicate_matcher ...@@ -45,12 +72,11 @@ struct predicate_matcher
{ {
P p; P p;
instruction_ref match(const matcher_context& ctx, instruction_ref ins) const optional<instruction_ref> match(const matcher_context&, instruction_ref ins) const
{ {
assert(ins != ctx.not_found());
if(p(ins)) if(p(ins))
return ins; return optional<instruction_ref>{ins};
return ctx.not_found(); return nullopt;
} }
}; };
...@@ -60,11 +86,7 @@ struct function_matcher ...@@ -60,11 +86,7 @@ struct function_matcher
{ {
F f; F f;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const auto match(matcher_context& ctx, instruction_ref ins) const { return f(ctx, ins); }
{
assert(ins != ctx.not_found());
return f(ctx, ins);
}
}; };
/// Convert a function into a matcher /// Convert a function into a matcher
...@@ -79,12 +101,17 @@ template <class M> ...@@ -79,12 +101,17 @@ template <class M>
auto bind_match(M m, std::string name) auto bind_match(M m, std::string name)
{ {
return make_function_matcher( return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) { [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins)
auto result = m.match(ctx, ins); ->optional<instruction_ref> {
if(result != ctx.not_found()) auto result = m.match(ctx, ins);
ctx.instructions[name] = ins; if(result)
return result; {
}); if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
} }
/// Convert a matcher to a bindable matcher /// Convert a matcher to a bindable matcher
...@@ -95,10 +122,7 @@ struct bindable_matcher ...@@ -95,10 +122,7 @@ struct bindable_matcher
auto bind(std::string name) const { return bind_match(m, std::move(name)); } auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
{
return m.match(ctx, ins);
}
}; };
/// Create a bindable matcher /// Create a bindable matcher
...@@ -126,7 +150,10 @@ using bool_list = std::initializer_list<bool>; ...@@ -126,7 +150,10 @@ using bool_list = std::initializer_list<bool>;
struct id_matcher struct id_matcher
{ {
instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; } auto match(matcher_context&, instruction_ref ins) const
{
return optional<instruction_ref>{ins};
}
}; };
/// The basic matcher provides the all_of composability of the matcher /// The basic matcher provides the all_of composability of the matcher
...@@ -140,26 +167,23 @@ struct basic_matcher ...@@ -140,26 +167,23 @@ struct basic_matcher
{ {
// Copy m because we cant capture `this` by value // Copy m because we cant capture `this` by value
auto mm = m; auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins); auto result = mm.match(ctx, ins);
if(result != ctx.not_found()) if(result)
{ {
bool matches = fold([&](auto x, auto y) { bool matches =
return x and y.match(ctx, result) != ctx.not_found(); fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
})(true, ms...);
if(matches) if(matches)
return result; return result;
} }
return ctx.not_found(); return nullopt;
}); });
} }
auto bind(std::string name) const { return bind_match(m, std::move(name)); } auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
{
return m.match(ctx, ins);
}
}; };
/// Create a basic matcher from a matcher /// Create a basic matcher from a matcher
...@@ -185,7 +209,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p) ...@@ -185,7 +209,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
/// Create a typed-erased matcher /// Create a typed-erased matcher
using any_matcher_base = basic_matcher< using any_matcher_base = basic_matcher<
function_matcher<std::function<instruction_ref(matcher_context&, instruction_ref)>>>; function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base struct any_matcher : any_matcher_base
{ {
template <class M> template <class M>
...@@ -198,10 +222,10 @@ struct any_matcher : any_matcher_base ...@@ -198,10 +222,10 @@ struct any_matcher : any_matcher_base
#define MIGRAPHX_BASIC_MATCHER(name, ...) \ #define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \ struct name##_m \
{ \ { \
instruction_ref match(__VA_ARGS__) const; \ optional<instruction_ref> match(__VA_ARGS__) const; \
}; \ }; \
const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \ const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const inline optional<instruction_ref> name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher /// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPHX_PRED_MATCHER(name, ...) \ #define MIGRAPHX_PRED_MATCHER(name, ...) \
...@@ -221,21 +245,43 @@ struct matcher_result ...@@ -221,21 +245,43 @@ struct matcher_result
/// Match a single instruction /// Match a single instruction
template <class M> template <class M>
matcher_result match_instruction(module& p, instruction_ref ins, M&& m) matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{ {
assert(ins != p.end()); assert(ins != mod.end());
assert(mod.has_instruction(ins));
matcher_context ctx{mod};
matcher_result result; matcher_result result;
matcher_context ctx{p.end()}; if(m.match(ctx, ins))
result.result = m.match(ctx, ins); {
result.instructions = ctx.instructions; result.result = ins;
result.instructions = ctx.instructions;
}
else
{
result.result = mod.end();
}
return result;
}
/// Find first instance of a matching instruction in a module
template <class M>
match::matcher_result find_match(module& modl, M&& m)
{
match::matcher_result result;
for(auto ins : iterator_for(modl))
{
result = match::match_instruction(modl, ins, m);
if(result.result != modl.end())
return result;
}
return result; return result;
} }
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the program /// Find matches for an instruction in the module
template <class... Ms> template <class... Ms>
void find_matches(module& p, instruction_ref ins, Ms&&... ms) void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const const
...@@ -246,27 +292,27 @@ void find_matches(module& p, instruction_ref ins, Ms&&... ms) ...@@ -246,27 +292,27 @@ void find_matches(module& p, instruction_ref ins, Ms&&... ms)
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
return; return;
auto r = match_instruction(p, ins, m.matcher()); auto r = match_instruction(mod, ins, m.matcher());
if(r.result == p.end()) if(r.result == mod.end())
return; return;
if(trace) if(trace)
{ {
std::cout << "Matched by " << get_type_name(m) << std::endl; std::cout << "Matched by " << get_type_name(m) << std::endl;
p.debug_print(ins); mod.debug_print(ins);
} }
m.apply(p, r); m.apply(mod, r);
match = true; match = true;
}, },
ms...); ms...);
} }
/// Find matches in a program /// Find matches in a module
template <class... Ms> template <class... Ms>
void find_matches(module& p, Ms&&... ms) void find_matches(module& mod, Ms&&... ms)
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(mod))
{ {
find_matches(p, ins, ms...); find_matches(mod, ins, ms...);
} }
} }
...@@ -339,12 +385,13 @@ struct match_fold_f ...@@ -339,12 +385,13 @@ struct match_fold_f
template <class... Ts> template <class... Ts>
auto operator()(Ts... ms) const auto operator()(Ts... ms) const
{ {
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher(
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...); [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
if(matches == Matches) bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
return ins; if(matches == Matches)
return ctx.not_found(); return {ins};
}); return nullopt;
});
} }
template <class Selector> template <class Selector>
...@@ -353,17 +400,18 @@ struct match_fold_f ...@@ -353,17 +400,18 @@ struct match_fold_f
return [=](auto... ms) { return [=](auto... ms) {
// Workaround ICE on gcc by packing matchers into an object // Workaround ICE on gcc by packing matchers into an object
auto mpack = pack(ms...); auto mpack = pack(ms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) { return make_bf_matcher(
Op op; [=](matcher_context& ctx, instruction_ref start) -> optional<instruction_ref> {
bool matches = Start; Op op;
select(start, [&](auto ins) { bool matches = Start;
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); }; select(start, [&](auto ins) {
matches = op(always(matches), fm)(); auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm)();
});
if(matches == Matches)
return {start};
return nullopt;
}); });
if(matches == Matches)
return start;
return ctx.not_found();
});
}; };
} }
}; };
...@@ -420,64 +468,29 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins) ...@@ -420,64 +468,29 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; }); ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
} }
MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{ {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
return ins->outputs().front(); return {ins->outputs().front()};
return ctx.not_found(); return nullopt;
} }
MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
{ {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
return ins; return {ins};
if(ins->outputs().empty() and std::next(ins) == ctx.not_found()) if(ins->outputs().empty() and ctx.is_last(ins))
return ins; return {ins};
return ctx.not_found(); return nullopt;
}
inline auto used_once_recursive(std::size_t depth)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
// Used once
if(start->outputs().size() == 1)
return start;
// Unused
if(start->outputs().empty())
{
if(std::next(start) == ctx.not_found())
return start;
else
return ctx.not_found();
}
// Check for dead instructions
auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
if(n == 0)
return false;
if(ins->get_shape().elements() == 0)
return false;
if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
return true;
return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return self(i, n - 1);
});
});
auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
return is_dead(i, depth);
});
if(dead + 1 == start->outputs().size())
return start;
return ctx.not_found();
});
} }
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); } MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{ {
if(ins->outputs().empty() and ins != std::prev(ctx.not_found())) if(ins->outputs().empty() and not ctx.is_last(ins))
return ins; return {ins};
return ctx.not_found(); return nullopt;
} }
template <class... Ms> template <class... Ms>
...@@ -485,14 +498,15 @@ auto skip(Ms... ms) ...@@ -485,14 +498,15 @@ auto skip(Ms... ms)
{ {
auto m = any_of(ms...); auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) { return fix<optional<instruction_ref>>(
if(ins->inputs().size() == 1 and ctx.matched(m, ins)) [&](auto self, auto ins) -> optional<instruction_ref> {
{ if(ins->inputs().size() == 1 and ctx.matched(m, ins))
auto next = ins->inputs().front(); {
return self(next); auto next = ins->inputs().front();
} return self(next);
return ins; }
})(start); return ins;
})(start);
}); });
} }
...@@ -501,20 +515,21 @@ auto skip_output(Ms... ms) ...@@ -501,20 +515,21 @@ auto skip_output(Ms... ms)
{ {
auto m = any_of(ms...); auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) { return fix<optional<instruction_ref>>(
if(ins->outputs().size() == 1) [&](auto self, auto ins) -> optional<instruction_ref> {
{ if(ins->outputs().size() == 1)
auto next = ins->outputs().front();
if(ctx.matched(m, next))
{ {
auto skipped_next = self(next); auto next = ins->outputs().front();
if(skipped_next != ctx.not_found()) if(ctx.matched(m, next))
return skipped_next; {
auto skipped_next = self(next);
if(skipped_next)
return skipped_next;
}
return next;
} }
return next; return nullopt;
} })(start);
return ctx.not_found();
})(start);
}); });
} }
...@@ -550,11 +565,12 @@ inline auto nargs(std::size_t n) ...@@ -550,11 +565,12 @@ inline auto nargs(std::size_t n)
inline auto arg(std::size_t i) inline auto arg(std::size_t i)
{ {
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
if(i < ins->inputs().size()) [=](const matcher_context&, instruction_ref ins) -> optional<instruction_ref> {
return ins->inputs()[i]; if(i < ins->inputs().size())
return ctx.not_found(); return ins->inputs()[i];
}); return nullopt;
});
} }
// Workaround for bugs in clang // Workaround for bugs in clang
...@@ -616,52 +632,56 @@ std::size_t tree_leafs_impl(matcher_context& ctx, ...@@ -616,52 +632,56 @@ std::size_t tree_leafs_impl(matcher_context& ctx,
template <class M, class... Ms> template <class M, class... Ms>
auto tree(M main_op, Ms... ms) auto tree(M main_op, Ms... ms)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
// Flatten leaf nodes [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
std::array<instruction_ref, sizeof...(Ms)> leafs; // Flatten leaf nodes
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins); std::array<instruction_ref, sizeof...(Ms)> leafs;
if(idx != leafs.size()) std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
return ctx.not_found(); if(idx != leafs.size())
// Use explicit captures to workaround ICE on gcc return nullopt;
bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) { // Use explicit captures to workaround ICE on gcc
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)(); // Capture by value to workaround compile error on gcc 9
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
});
if(not found)
return nullopt;
return ins;
}); });
if(not found)
return ctx.not_found();
return ins;
});
} }
template <class M, class... Ms> template <class M, class... Ms>
auto unordered_tree(M main_op, Ms... ms) auto unordered_tree(M main_op, Ms... ms)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
// Flatten leaf nodes [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
std::array<instruction_ref, sizeof...(Ms)> leafs; // Flatten leaf nodes
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins); std::array<instruction_ref, sizeof...(Ms)> leafs;
if(idx != leafs.size()) std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
return ctx.not_found(); if(idx != leafs.size())
// Use explicit captures to workaround ICE on gcc return nullopt;
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) { // Use explicit captures to workaround ICE on gcc
return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) { bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...); return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
})(ms...)(); return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
})(ms...)();
});
if(not found)
return nullopt;
return ins;
}); });
if(not found)
return ctx.not_found();
return ins;
});
} }
template <class M> template <class M>
auto same_shape(M m) auto same_shape(M m)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
auto i = m.match(ctx, ins); [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
if(i != ctx.not_found() and i->get_shape() == ins->get_shape()) auto i = m.match(ctx, ins);
return ins; if(i and (*i)->get_shape() == ins->get_shape())
return ctx.not_found(); return ins;
}); return nullopt;
});
} }
template <class... Ms> template <class... Ms>
......
...@@ -46,6 +46,9 @@ struct module ...@@ -46,6 +46,9 @@ struct module
std::string name() const; std::string name() const;
bool bypass() const;
void set_bypass(bool b = true);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)> template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref add_instruction(operation op, Ts... args) instruction_ref add_instruction(operation op, Ts... args)
{ {
......
...@@ -18,6 +18,8 @@ struct onnx_options ...@@ -18,6 +18,8 @@ struct onnx_options
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
/// Print program if an error occurs /// Print program if an error occurs
bool print_program_on_error = false; bool print_program_on_error = false;
/// Max iter num for the loop operator
int64_t max_loop_iterations = 10;
}; };
/// Create a program from an onnx file /// Create a program from an onnx file
...@@ -29,6 +31,8 @@ program parse_onnx_buffer(const std::string& buffer, const onnx_options& options ...@@ -29,6 +31,8 @@ program parse_onnx_buffer(const std::string& buffer, const onnx_options& options
/// Create a program from an onnx buffer /// Create a program from an onnx buffer
program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options); program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options);
std::vector<std::string> get_onnx_operators();
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -35,7 +36,7 @@ struct as_shape ...@@ -35,7 +36,7 @@ struct as_shape
{ {
return args.front().reshape(output_shape); return args.front().reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -29,7 +30,7 @@ struct broadcast ...@@ -29,7 +30,7 @@ struct broadcast
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims")); return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens"));
} }
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
...@@ -66,7 +67,7 @@ struct broadcast ...@@ -66,7 +67,7 @@ struct broadcast
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -29,7 +30,9 @@ struct capture ...@@ -29,7 +30,9 @@ struct capture
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(const shape&, std::vector<argument> args) const // the context argument is added to prevent the op from be eliminated by
// constant propagation
argument compute(context&, const shape&, const std::vector<argument>& args) const
{ {
if(f) if(f)
{ {
...@@ -42,6 +45,8 @@ struct capture ...@@ -42,6 +45,8 @@ struct capture
return args.front(); return args.front();
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
} // namespace op } // namespace op
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -39,25 +41,30 @@ struct convolution ...@@ -39,25 +41,30 @@ struct convolution
void check_attribute_size() const void check_attribute_size() const
{ {
if(not(padding.size() == stride.size() and padding.size() == dilation.size())) if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{ {
MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes"); MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
} }
} }
shape compute_shape(std::vector<shape> inputs) const value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3); check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size(); check_attribute_size();
// dim num of input and attribute should match // dim num of input and attribute should match
if(inputs[0].lens().size() != padding.size() + 2) auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{ {
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!"); MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
} }
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
size_t kdims = input.lens().size() - 2; size_t kdims = input_size - 2;
if(kdims != this->kdims()) if(kdims != this->kdims())
{ {
MIGRAPHX_THROW("convolution: input k-dims does not match attribute size"); MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
...@@ -70,10 +77,13 @@ struct convolution ...@@ -70,10 +77,13 @@ struct convolution
for(size_t i = 0; i < kdims; i++) for(size_t i = 0; i < kdims; i++)
{ {
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>( output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) + (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
2 * padding[i]) / padding_factor) /
stride[i] + stride[i] +
1))); 1)));
} }
...@@ -84,7 +94,7 @@ struct convolution ...@@ -84,7 +94,7 @@ struct convolution
size_t kdims() const size_t kdims() const
{ {
check_attribute_size(); check_attribute_size();
return padding.size(); return stride.size();
} }
}; };
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_dfor.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -39,7 +41,8 @@ struct deconvolution ...@@ -39,7 +41,8 @@ struct deconvolution
void check_attribute_size() const void check_attribute_size() const
{ {
if(not(padding.size() == stride.size() and padding.size() == dilation.size())) if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{ {
MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes"); MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes");
} }
...@@ -69,10 +72,85 @@ struct deconvolution ...@@ -69,10 +72,85 @@ struct deconvolution
return inputs[0].with_lens(output_lens); return inputs[0].with_lens(output_lens);
} }
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto kdims = this->kdims();
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
using type = typename decltype(output)::value_type;
std::fill(output.begin(), output.end(), type{0});
auto in_lens = input.get_shape().lens();
auto in_n = in_lens[0];
auto in_c = in_lens[1];
auto wei = weights.get_shape().lens();
auto wei_n = wei[0];
auto wei_c = wei[1];
auto out_lens = output_shape.lens();
std::vector<std::size_t> win_size{in_c};
std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size));
std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size));
shape win_shape{output_shape.type(), win_size};
par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) {
const int w = idx_win[0];
auto input_dims_start = idx_win.begin() + 1;
auto wei_dims_start = idx_win.begin() + kdims + 1;
std::vector<std::ptrdiff_t> win_start;
for(std::size_t n = 0; n < kdims; ++n)
{
win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * stride[n]) -
std::ptrdiff_t(padding[n]));
}
const int group_id = w / (wei_n / group);
const int in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx_out{o, in_ch};
for(size_t n = 0; n < kdims; n++)
{
idx_out.push_back(win_start[n] + *(wei_dims_start + n) * dilation[n]);
}
std::vector<std::ptrdiff_t> idx_wei{w, k};
std::copy(wei_dims_start, idx_win.end(), std::back_inserter(idx_wei));
std::vector<std::ptrdiff_t> idx_in{o, w};
std::copy(input_dims_start, wei_dims_start, std::back_inserter(idx_in));
if(std::all_of(
idx_out.begin() + 2, idx_out.end(), [&](auto ii) { return ii >= 0; }) and
std::equal(idx_out.begin() + 2,
idx_out.end(),
out_lens.begin() + 2,
out_lens.end(),
std::less<std::ptrdiff_t>{}))
{
output(idx_out.begin(), idx_out.end()) +=
input(idx_in.begin(), idx_in.end()) *
weights(idx_wei.begin(), idx_wei.end());
}
});
});
});
return result;
}
size_t kdims() const size_t kdims() const
{ {
check_attribute_size(); check_attribute_size();
return padding.size(); return stride.size();
} }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct dequantizelinear
{
std::string name() const { return "dequantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_dims();
return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto x = args.at(0);
auto x_scale = args.at(1);
std::vector<int8_t> zeros(output_shape.bytes(), 0);
argument x_zero_point{{x.get_shape().type(), output_shape.lens()}, zeros.data()};
if(args.size() == 3)
{
x_zero_point = args.at(2);
}
argument result{output_shape};
visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
visit_all(result, x_scale)([&](auto output, auto scales) {
par_for(output_shape.elements(), [&](auto i) {
output[i] = static_cast<double>(static_cast<int64_t>(input[i]) -
static_cast<int64_t>(zero_pts[i])) *
scales[i];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -18,19 +18,10 @@ namespace op { ...@@ -18,19 +18,10 @@ namespace op {
struct dot struct dot
{ {
float alpha = 1.0;
float beta = 1.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::string name() const { return "dot"; } std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.same_type(); check_shapes{inputs, *this}.same_type().has(2);
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
...@@ -58,25 +49,14 @@ struct dot ...@@ -58,25 +49,14 @@ struct dot
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
return {t, out_lens}; return {t, out_lens};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
argument result; argument result = argument{output_shape};
if(args.size() == 3)
result = args[2];
else
result = argument{output_shape};
visit_all(result, args[0], args[1])( visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, alpha, beta); }); [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result; return result;
} }
}; };
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -38,7 +39,7 @@ struct flatten ...@@ -38,7 +39,7 @@ struct flatten
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1).standard();
auto&& lens = inputs.front().lens(); auto&& lens = inputs.front().lens();
auto x = auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{}); std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
...@@ -50,7 +51,7 @@ struct flatten ...@@ -50,7 +51,7 @@ struct flatten
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#define MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#include "migraphx/errors.hpp"
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct get_tuple_elem
{
std::size_t index = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.index, "index"));
}
std::string name() const { return "get_tuple_elem"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).tuple_type();
const auto& sub_shapes = inputs.at(0).sub_shapes();
if(index >= sub_shapes.size())
{
MIGRAPHX_THROW("GET_TUPLE_ELEM: index " + std::to_string(index) + " is out of range " +
std::to_string(sub_shapes.size()));
}
return sub_shapes.at(index);
}
argument compute(const shape&, std::vector<argument> args) const
{
assert(args.size() == 1);
auto vec_args = args.at(0).get_sub_objects();
assert(index < vec_args.size());
return vec_args.at(index);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -35,14 +35,14 @@ struct if_op ...@@ -35,14 +35,14 @@ struct if_op
MIGRAPHX_THROW("IF: output shapes of submodules must be the same."); MIGRAPHX_THROW("IF: output shapes of submodules must be the same.");
} }
return out_shapes0.front(); return shape(out_shapes0);
} }
argument compute( argument compute(const shape&,
const std::vector<argument>& args, const std::vector<argument>& args,
const std::vector<module_ref>& mods, const std::vector<module_ref>& mods,
const std::function<std::vector<argument>( const std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)>& run) const module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{ {
auto cond = args.front().at<bool>(); auto cond = args.front().at<bool>();
module_ref mod = cond ? mods[0] : mods[1]; module_ref mod = cond ? mods[0] : mods[1];
...@@ -63,7 +63,7 @@ struct if_op ...@@ -63,7 +63,7 @@ struct if_op
[](auto&& name, auto&& arg) { return std::make_pair(name, arg); }); [](auto&& name, auto&& arg) { return std::make_pair(name, arg); });
auto results = run(mod, params); auto results = run(mod, params);
return results[0]; return argument{results};
} }
}; };
......
...@@ -31,7 +31,9 @@ struct im2col ...@@ -31,7 +31,9 @@ struct im2col
std::string name() const { return "im2col"; } std::string name() const { return "im2col"; }
shape compute_shape(std::vector<shape> inputs) const value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
auto input = inputs[0]; auto input = inputs[0];
auto weights = inputs[1]; auto weights = inputs[1];
...@@ -42,17 +44,24 @@ struct im2col ...@@ -42,17 +44,24 @@ struct im2col
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
if(batch_size != 1) if(batch_size != 1)
MIGRAPHX_THROW("im2col only support batch_size 1"); MIGRAPHX_THROW("im2col only support batch_size 1");
auto padding_h = 2 * padding[0];
auto padding_w = 2 * padding[1];
if(padding.size() == 2 * stride.size())
{
padding_h = padding[0] + padding[2];
padding_w = padding[1] + padding[3];
}
auto output_height = std::size_t(std::max<std::ptrdiff_t>( auto output_height = std::size_t(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) / (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + padding_h) / stride[0] +
stride[0] +
1)); 1));
auto output_width = std::size_t(std::max<std::ptrdiff_t>( auto output_width = std::size_t(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) / (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + padding_w) / stride[1] +
stride[1] +
1)); 1));
auto channels_col = kernel_height * kernel_width * input_channels;
auto channels_col = kernel_height * kernel_width * input_channels;
return {input.type(), {output_height * output_width, channels_col}}; return {input.type(), {output_height * output_width, channels_col}};
} }
}; };
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -34,9 +35,9 @@ struct load ...@@ -34,9 +35,9 @@ struct load
{ {
if((offset + s.bytes()) > args[0].get_shape().bytes()) if((offset + s.bytes()) > args[0].get_shape().bytes())
MIGRAPHX_THROW("Load access is out of bounds"); MIGRAPHX_THROW("Load access is out of bounds");
return argument::load(s, args[0].data() + offset); return argument{s, args[0].data() + offset};
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op) friend std::ostream& operator<<(std::ostream& os, const load& op)
......
#ifndef MIGRAPHX_GUARD_OPERATORS_LOOP_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOOP_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/module.hpp>
#include <migraphx/run_loop.hpp>
#include <migraphx/ranges.hpp>
#include <cmath>
#include <string>
#include <utility>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct loop
{
int64_t max_iterations = 10;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.max_iterations, "max_iterations"));
}
std::string name() const { return "loop"; }
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
check_shapes{inputs, *this}.standard();
if(mods.size() != 1)
{
MIGRAPHX_THROW("LOOP: operator should have one submodule.");
}
const auto& mod = mods.front();
auto mod_out_shapes = mod->get_output_shapes();
auto dep_param_num = inputs.size() - 2;
// first item of the mod output shapes is condition used in loop,
// which is not needed to compute output shape
mod_out_shapes.erase(mod_out_shapes.begin());
std::vector<shape> ins_out_shapes(mod_out_shapes.begin(),
mod_out_shapes.begin() + dep_param_num);
mod_out_shapes.erase(mod_out_shapes.begin(), mod_out_shapes.begin() + dep_param_num);
for(const auto& out_s : mod_out_shapes)
{
auto lens = out_s.lens();
lens.insert(lens.begin(), max_iterations);
ins_out_shapes.push_back({out_s.type(), lens});
}
return shape(ins_out_shapes);
}
struct ref_loop
{
int64_t max_iterations = 0;
template <class T>
void copy(context&, const argument& src, T& dst) const
{
dst = *src.cast<T>();
}
template <class T>
void copy(context&, T src, const argument& dst) const
{
*dst.cast<T>() = src;
}
void append(const std::vector<argument>& iter_state,
const std::vector<argument>& concatenated_outputs,
int iter) const
{
assert(iter_state.size() == concatenated_outputs.size());
for(auto i : range(iter_state.size()))
{
const auto& iter_stat = iter_state.at(i);
const auto& scan_out = concatenated_outputs.at(i);
auto* in_data = iter_stat.data();
auto* out_data = scan_out.data();
std::size_t out_size = iter_stat.get_shape().bytes();
assert((iter + 1) * out_size <= scan_out.get_shape().bytes());
std::copy(in_data, in_data + out_size, out_data + iter * out_size);
}
}
void set_zero(context&, const std::vector<argument>& concatenated_outputs, int iter) const
{
if(iter >= max_iterations)
return;
for(const auto& out : concatenated_outputs)
{
auto s = out.get_shape();
auto size = s.bytes() / max_iterations;
std::fill(out.data() + iter * size, out.data() + max_iterations * size, 0);
}
}
std::unordered_map<std::string, int> get_output_params(const module&) const { return {}; }
};
argument compute(context& ctx,
const shape& out_shape,
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
// wrap up the arguments vector, so ref and gpu impl are the same
auto cpy_args = args;
bool in_cond = args.at(1).at<bool>();
bool cond = in_cond;
int64_t iter = 0;
// insert iter and cond used in the loop
auto s_cond = args.at(1).get_shape();
auto s_iter = args.at(0).get_shape();
cpy_args.push_back({s_iter, &iter});
cpy_args.push_back({s_cond, &cond});
cpy_args.insert(cpy_args.end(), args.begin() + 2, args.end());
// add cond and mod outputs to the argument list
cpy_args.push_back(argument(s_cond));
cpy_args.push_back(argument(out_shape));
// run loop
return run_loop(ref_loop{max_iterations}, ctx, cpy_args, mods, run);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -22,7 +23,7 @@ struct multibroadcast ...@@ -22,7 +23,7 @@ struct multibroadcast
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.output_lens, "output_lens")); return pack(f(self.output_lens, "out_lens"));
} }
std::string name() const { return "multibroadcast"; } std::string name() const { return "multibroadcast"; }
...@@ -68,7 +69,7 @@ struct multibroadcast ...@@ -68,7 +69,7 @@ struct multibroadcast
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_MULTINOMIAL_HPP
#define MIGRAPHX_GUARD_OPERATORS_MULTINOMIAL_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_for.hpp>
#include <random>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct multinomial
{
shape::type_t dtype = shape::type_t::int32_type;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dtype, "dtype"));
}
std::string name() const { return "multinomial"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).only_dims(2);
size_t sample_size = inputs.back().lens().back();
if(not contains({shape::int32_type, shape::int64_type}, dtype))
MIGRAPHX_THROW(
"Multinomial: Invalid output type. Valid types are int32_type and int64_type.");
return {dtype, {inputs.front().lens().front(), sample_size}};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
size_t batch_size = output_shape.lens().front();
size_t class_size = args[0].get_shape().lens().back();
size_t sample_size = output_shape.lens().back();
visit_all(args[0], args[1])([&](auto cdf, auto dist) {
result.visit([&](auto output) {
par_for(batch_size * sample_size, [&](auto i) {
auto idx = args[1].get_shape().multi(i);
auto cdf_begin = cdf.begin() + (idx[0] * class_size);
auto cdf_end = cdf_begin + class_size;
auto sample_iter =
std::upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
output[i] = std::distance(cdf_begin, sample_iter);
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_NONZERO_HPP
#define MIGRAPHX_GUARD_OPERATORS_NONZERO_HPP
#include <migraphx/shape_for_each.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/par_for.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct nonzero
{
std::string name() const { return "nonzero"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto elem_num = inputs[0].elements();
auto dim_num = inputs[0].lens().size();
std::vector<std::size_t> out_lens = {dim_num, elem_num};
return {shape::int64_type, out_lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
std::vector<std::vector<std::size_t>> vec_idx;
auto s = args.front().get_shape();
args.front().visit([&](auto v) {
shape_for_each(s, [&](auto idx) {
if(not float_equal(v[s.index(idx)], 0))
{
vec_idx.push_back(idx);
}
});
});
argument result{output_shape};
result.visit([&](auto output) {
std::fill(output.begin(), output.end(), 0);
par_for(vec_idx.size(), [&](auto i) {
for(std::size_t j = 0; j < vec_idx.front().size(); ++j)
{
output[output_shape.index({j, i})] = vec_idx[i][j];
}
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment