Commit 4a39a0f7 authored by Shucai Xiao

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
enum class lifetime
{
local,
global,
borrow
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
@@ -5,7 +5,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
@@ -19,24 +19,51 @@ namespace match {
struct matcher_context
{
matcher_context(instruction_ref i) : last(i) {}
matcher_context(module& m) : mod(&m) {}
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template <class M>
bool matched(M m, instruction_ref ins)
{
return m.match(*this, ins) != this->not_found();
return has_value(m.match(*this, ins));
}
template <class M>
auto lazy_match(M m, instruction_ref ins)
bool matched(M m, optional<instruction_ref> ins)
{
if(ins)
return has_value(m.match(*this, *ins));
return false;
}
template <class M, class I>
auto lazy_match(M m, I ins)
{
return [=] { return this->matched(m, ins); };
}
bool has_instruction(instruction_ref ins) const
{
if(mod == nullptr)
return true;
return mod->has_instruction(ins);
}
bool has_instruction(optional<instruction_ref> ins) const
{
if(ins)
return this->has_instruction(*ins);
return false;
}
bool is_last(instruction_ref ins) const
{
assert(mod->begin() != mod->end());
assert(this->has_instruction(ins));
return ins == std::prev(mod->end());
}
private:
instruction_ref last;
module* mod = nullptr;
};
/// Convert a predicate function into a matcher
@@ -45,12 +72,11 @@ struct predicate_matcher
{
P p;
instruction_ref match(const matcher_context& ctx, instruction_ref ins) const
optional<instruction_ref> match(const matcher_context&, instruction_ref ins) const
{
assert(ins != ctx.not_found());
if(p(ins))
return ins;
return ctx.not_found();
return optional<instruction_ref>{ins};
return nullopt;
}
};
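For illustration, the change above replaces the end-iterator sentinel (ctx.not_found()) with optional<instruction_ref>: a matcher now returns the matched instruction or nullopt, and callers test has_value() instead of comparing against a sentinel. A minimal self-contained sketch of the same pattern over plain ints, with std::optional standing in for migraphx::optional:

#include <iostream>
#include <optional>

// A predicate "matcher" in the new style: return the value on success
// and nullopt on failure, instead of a reserved sentinel value.
template <class P>
std::optional<int> match_if(P p, int x)
{
    if(p(x))
        return x;
    return std::nullopt;
}

int main()
{
    auto is_even = [](int x) { return x % 2 == 0; };
    std::cout << match_if(is_even, 4).has_value() << '\n'; // prints 1
    std::cout << match_if(is_even, 5).has_value() << '\n'; // prints 0
}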
@@ -60,11 +86,7 @@ struct function_matcher
{
F f;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
assert(ins != ctx.not_found());
return f(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return f(ctx, ins); }
};
/// Convert a function into a matcher
@@ -79,12 +101,17 @@ template <class M>
auto bind_match(M m, std::string name)
{
return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
auto result = m.match(ctx, ins);
if(result != ctx.not_found())
ctx.instructions[name] = ins;
return result;
});
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins)
->optional<instruction_ref> {
auto result = m.match(ctx, ins);
if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
}
/// Convert a matcher to a bindable matcher
@@ -95,10 +122,7 @@ struct bindable_matcher
auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
};
/// Create a bindable matcher
@@ -126,7 +150,10 @@ using bool_list = std::initializer_list<bool>;
struct id_matcher
{
instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
auto match(matcher_context&, instruction_ref ins) const
{
return optional<instruction_ref>{ins};
}
};
/// The basic matcher provides the all_of composability of the matcher
@@ -140,26 +167,23 @@ struct basic_matcher
{
// Copy m because we can't capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result != ctx.not_found())
if(result)
{
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, result) != ctx.not_found();
})(true, ms...);
bool matches =
fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
if(matches)
return result;
}
return ctx.not_found();
return nullopt;
});
}
auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
};
/// Create a basic matcher from a matcher
@@ -185,7 +209,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
/// Create a type-erased matcher
using any_matcher_base = basic_matcher<
function_matcher<std::function<instruction_ref(matcher_context&, instruction_ref)>>>;
function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base
{
template <class M>
@@ -198,10 +222,10 @@ struct any_matcher : any_matcher_base
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
optional<instruction_ref> match(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const
inline optional<instruction_ref> name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPHX_PRED_MATCHER(name, ...) \
@@ -221,21 +245,43 @@ struct matcher_result
/// Match a single instruction
template <class M>
matcher_result match_instruction(module& p, instruction_ref ins, M&& m)
matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
assert(ins != p.end());
assert(ins != mod.end());
assert(mod.has_instruction(ins));
matcher_context ctx{mod};
matcher_result result;
matcher_context ctx{p.end()};
result.result = m.match(ctx, ins);
result.instructions = ctx.instructions;
if(m.match(ctx, ins))
{
result.result = ins;
result.instructions = ctx.instructions;
}
else
{
result.result = mod.end();
}
return result;
}
/// Find first instance of a matching instruction in a module
template <class M>
match::matcher_result find_match(module& modl, M&& m)
{
match::matcher_result result;
for(auto ins : iterator_for(modl))
{
result = match::match_instruction(modl, ins, m);
if(result.result != modl.end())
return result;
}
return result;
}
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the program
/// Find matches for an instruction in the module
template <class... Ms>
void find_matches(module& p, instruction_ref ins, Ms&&... ms)
void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
@@ -246,27 +292,27 @@ void find_matches(module& p, instruction_ref ins, Ms&&... ms)
[&](auto&& m) {
if(match)
return;
auto r = match_instruction(p, ins, m.matcher());
if(r.result == p.end())
auto r = match_instruction(mod, ins, m.matcher());
if(r.result == mod.end())
return;
if(trace)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
p.debug_print(ins);
mod.debug_print(ins);
}
m.apply(p, r);
m.apply(mod, r);
match = true;
},
ms...);
}
/// Find matches in a program
/// Find matches in a module
template <class... Ms>
void find_matches(module& p, Ms&&... ms)
void find_matches(module& mod, Ms&&... ms)
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(mod))
{
find_matches(p, ins, ms...);
find_matches(mod, ins, ms...);
}
}
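For context, each callback passed to find_matches must expose the two members invoked above: matcher(), returning the pattern, and apply(), run on a successful match. A hypothetical callback in that shape (not part of this diff; match::name is assumed to be the instruction-name matcher defined elsewhere in this header, and <migraphx/matcher.hpp> the usual include path):

#include <migraphx/matcher.hpp>

namespace match = migraphx::match;

// Hypothetical rewriter callback: print every `add` instruction found.
struct trace_adds
{
    auto matcher() const { return match::name("add"); }

    void apply(migraphx::module& mod, const match::matcher_result& r) const
    {
        mod.debug_print(r.result);
    }
};

// Usage: match::find_matches(mod, trace_adds{});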
@@ -339,12 +385,13 @@ struct match_fold_f
template <class... Ts>
auto operator()(Ts... ms) const
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches)
return ins;
return ctx.not_found();
});
return make_bf_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches)
return {ins};
return nullopt;
});
}
template <class Selector>
@@ -353,17 +400,18 @@ struct match_fold_f
return [=](auto... ms) {
// Work around an ICE on gcc by packing matchers into an object
auto mpack = pack(ms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm)();
return make_bf_matcher(
[=](matcher_context& ctx, instruction_ref start) -> optional<instruction_ref> {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm)();
});
if(matches == Matches)
return {start};
return nullopt;
});
if(matches == Matches)
return start;
return ctx.not_found();
});
};
}
};
@@ -420,64 +468,29 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins)
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins->outputs().front();
return ctx.not_found();
return {ins->outputs().front()};
return nullopt;
}
MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins;
if(ins->outputs().empty() and std::next(ins) == ctx.not_found())
return ins;
return ctx.not_found();
}
inline auto used_once_recursive(std::size_t depth)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
// Used once
if(start->outputs().size() == 1)
return start;
// Unused
if(start->outputs().empty())
{
if(std::next(start) == ctx.not_found())
return start;
else
return ctx.not_found();
}
// Check for dead instructions
auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
if(n == 0)
return false;
if(ins->get_shape().elements() == 0)
return false;
if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
return true;
return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return self(i, n - 1);
});
});
auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
return is_dead(i, depth);
});
if(dead + 1 == start->outputs().size())
return start;
return ctx.not_found();
});
return {ins};
if(ins->outputs().empty() and ctx.is_last(ins))
return {ins};
return nullopt;
}
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
return ins;
return ctx.not_found();
if(ins->outputs().empty() and not ctx.is_last(ins))
return {ins};
return nullopt;
}
template <class... Ms>
@@ -485,14 +498,15 @@ auto skip(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) {
if(ins->inputs().size() == 1 and ctx.matched(m, ins))
{
auto next = ins->inputs().front();
return self(next);
}
return ins;
})(start);
return fix<optional<instruction_ref>>(
[&](auto self, auto ins) -> optional<instruction_ref> {
if(ins->inputs().size() == 1 and ctx.matched(m, ins))
{
auto next = ins->inputs().front();
return self(next);
}
return ins;
})(start);
});
}
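skip follows a chain of single-input instructions upward for as long as the inner matcher keeps accepting them, with the fix combinator tying the recursive knot for the lambda. A toy stand-in for fix over std types (not the MIGraphX implementation) that shows the same mechanism:

#include <functional>
#include <iostream>

// Toy fix combinator: lets a lambda call itself without being named.
template <class R, class F>
auto fix(F f)
{
    return [=](auto x) {
        std::function<R(decltype(x))> self = [&](auto y) { return f(self, y); };
        return f(self, x);
    };
}

int main()
{
    // "Skip" over even numbers by halving, analogous to skip()
    // following single-input chains while the matcher accepts.
    auto skip_even = fix<int>([](auto self, int n) -> int {
        if(n % 2 == 0)
            return self(n / 2);
        return n;
    });
    std::cout << skip_even(40) << '\n'; // prints 5
}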
@@ -501,20 +515,21 @@ auto skip_output(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) {
if(ins->outputs().size() == 1)
{
auto next = ins->outputs().front();
if(ctx.matched(m, next))
return fix<optional<instruction_ref>>(
[&](auto self, auto ins) -> optional<instruction_ref> {
if(ins->outputs().size() == 1)
{
auto skipped_next = self(next);
if(skipped_next != ctx.not_found())
return skipped_next;
auto next = ins->outputs().front();
if(ctx.matched(m, next))
{
auto skipped_next = self(next);
if(skipped_next)
return skipped_next;
}
return next;
}
return next;
}
return ctx.not_found();
})(start);
return nullopt;
})(start);
});
}
@@ -550,11 +565,12 @@ inline auto nargs(std::size_t n)
inline auto arg(std::size_t i)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) {
if(i < ins->inputs().size())
return ins->inputs()[i];
return ctx.not_found();
});
return make_basic_fun_matcher(
[=](const matcher_context&, instruction_ref ins) -> optional<instruction_ref> {
if(i < ins->inputs().size())
return ins->inputs()[i];
return nullopt;
});
}
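Since arg yields a basic matcher, it composes through basic_matcher::operator() (the all_of fold shown earlier): the outer matcher selects an instruction, then the inner matchers run against that selection. An illustrative composition, not taken from this diff:

namespace match = migraphx::match;

// Matches an instruction whose first input can be constant-folded:
// arg(0) selects inputs()[0], then is_constant is applied to it.
inline auto first_input_is_constant()
{
    return match::arg(0)(match::is_constant);
}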
// Workaround for bugs in clang
@@ -616,52 +632,56 @@ std::size_t tree_leafs_impl(matcher_context& ctx,
template <class M, class... Ms>
auto tree(M main_op, Ms... ms)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return ctx.not_found();
// Use explicit captures to work around an ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return nullopt;
// Use explicit captures to work around an ICE on gcc
// Capture by value to work around a compile error on gcc 9
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
});
if(not found)
return nullopt;
return ins;
});
if(not found)
return ctx.not_found();
return ins;
});
}
template <class M, class... Ms>
auto unordered_tree(M main_op, Ms... ms)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return ctx.not_found();
// Use explicit captures to work around an ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
})(ms...)();
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return nullopt;
// Use explicit captures to work around an ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
})(ms...)();
});
if(not found)
return nullopt;
return ins;
});
if(not found)
return ctx.not_found();
return ins;
});
}
template <class M>
auto same_shape(M m)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
auto i = m.match(ctx, ins);
if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
return ins;
return ctx.not_found();
});
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
auto i = m.match(ctx, ins);
if(i and (*i)->get_shape() == ins->get_shape())
return ins;
return nullopt;
});
}
template <class... Ms>
......
@@ -46,6 +46,9 @@ struct module
std::string name() const;
bool bypass() const;
void set_bypass(bool b = true);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref add_instruction(operation op, Ts... args)
{
......
@@ -18,6 +18,8 @@ struct onnx_options
bool skip_unknown_operators = false;
/// Print program if an error occurs
bool print_program_on_error = false;
/// Maximum number of iterations for the loop operator
int64_t max_loop_iterations = 10;
};
/// Create a program from an onnx file
@@ -29,6 +31,8 @@ program parse_onnx_buffer(const std::string& buffer, const onnx_options& options
/// Create a program from an onnx buffer
program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options);
std::vector<std::string> get_onnx_operators();
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
@@ -35,7 +36,7 @@ struct as_shape
{
return args.front().reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
@@ -29,7 +30,7 @@ struct broadcast
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims"));
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens"));
}
std::string name() const { return "broadcast"; }
@@ -66,7 +67,7 @@ struct broadcast
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <cmath>
#include <utility>
@@ -29,7 +30,9 @@ struct capture
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(const shape&, std::vector<argument> args) const
// the context argument is added to prevent the op from being eliminated by
// constant propagation
argument compute(context&, const shape&, const std::vector<argument>& args) const
{
if(f)
{
@@ -42,6 +45,8 @@
return args.front();
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
@@ -9,6 +9,8 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
@@ -39,25 +41,30 @@ struct convolution
void check_attribute_size() const
{
if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{
MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
}
}
shape compute_shape(std::vector<shape> inputs) const
value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size();
// dim num of input and attribute should match
if(inputs[0].lens().size() != padding.size() + 2)
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
}
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
size_t kdims = input.lens().size() - 2;
size_t kdims = input_size - 2;
if(kdims != this->kdims())
{
MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
@@ -70,10 +77,13 @@ struct convolution
for(size_t i = 0; i < kdims; i++)
{
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
2 * padding[i]) /
padding_factor) /
stride[i] +
1)));
}
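The padding_factor logic above accepts padding given either as one value per spatial dimension (symmetric, contributing 2 * padding[i]) or as begin/end pairs (contributing padding[i] + padding[i + kdims]). Each output extent is max(1, (in - (1 + dilation * (kernel - 1)) + pad_total) / stride + 1); a standalone check of that arithmetic:

#include <algorithm>
#include <cstddef>
#include <iostream>

// Output extent of one spatial dimension, mirroring the expression
// above; pad_total is 2*p for symmetric padding, begin + end otherwise.
std::size_t conv_out_dim(std::ptrdiff_t in,
                         std::ptrdiff_t kernel,
                         std::ptrdiff_t pad_total,
                         std::ptrdiff_t stride,
                         std::ptrdiff_t dilation)
{
    return std::size_t(std::max<std::ptrdiff_t>(
        1, (in - (1 + dilation * (kernel - 1)) + pad_total) / stride + 1));
}

int main()
{
    std::cout << conv_out_dim(32, 3, 2, 1, 1) << '\n'; // 3x3, pad 1+1: 32
    std::cout << conv_out_dim(32, 3, 2, 2, 1) << '\n'; // same, stride 2: 16
}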
@@ -84,7 +94,7 @@ struct convolution
size_t kdims() const
{
check_attribute_size();
return padding.size();
return stride.size();
}
};
......
@@ -9,6 +9,8 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_dfor.hpp>
#include <cmath>
#include <utility>
@@ -39,7 +41,8 @@ struct deconvolution
void check_attribute_size() const
{
if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{
MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes");
}
@@ -69,10 +72,85 @@ struct deconvolution
return inputs[0].with_lens(output_lens);
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto kdims = this->kdims();
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
using type = typename decltype(output)::value_type;
std::fill(output.begin(), output.end(), type{0});
auto in_lens = input.get_shape().lens();
auto in_n = in_lens[0];
auto in_c = in_lens[1];
auto wei = weights.get_shape().lens();
auto wei_n = wei[0];
auto wei_c = wei[1];
auto out_lens = output_shape.lens();
std::vector<std::size_t> win_size{in_c};
std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size));
std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size));
shape win_shape{output_shape.type(), win_size};
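// Reference deconvolution as a scatter: for each input element and
// each weight element, compute the output coordinate implied by
// stride, dilation and padding, and accumulate input * weight into it
// when the coordinate lands in bounds.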
par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) {
const int w = idx_win[0];
auto input_dims_start = idx_win.begin() + 1;
auto wei_dims_start = idx_win.begin() + kdims + 1;
std::vector<std::ptrdiff_t> win_start;
for(std::size_t n = 0; n < kdims; ++n)
{
win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * stride[n]) -
std::ptrdiff_t(padding[n]));
}
const int group_id = w / (wei_n / group);
const int in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx_out{o, in_ch};
for(size_t n = 0; n < kdims; n++)
{
idx_out.push_back(win_start[n] + *(wei_dims_start + n) * dilation[n]);
}
std::vector<std::ptrdiff_t> idx_wei{w, k};
std::copy(wei_dims_start, idx_win.end(), std::back_inserter(idx_wei));
std::vector<std::ptrdiff_t> idx_in{o, w};
std::copy(input_dims_start, wei_dims_start, std::back_inserter(idx_in));
if(std::all_of(
idx_out.begin() + 2, idx_out.end(), [&](auto ii) { return ii >= 0; }) and
std::equal(idx_out.begin() + 2,
idx_out.end(),
out_lens.begin() + 2,
out_lens.end(),
std::less<std::ptrdiff_t>{}))
{
output(idx_out.begin(), idx_out.end()) +=
input(idx_in.begin(), idx_in.end()) *
weights(idx_wei.begin(), idx_wei.end());
}
});
});
});
return result;
}
size_t kdims() const
{
check_attribute_size();
return padding.size();
return stride.size();
}
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct dequantizelinear
{
std::string name() const { return "dequantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_dims();
return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto x = args.at(0);
auto x_scale = args.at(1);
std::vector<int8_t> zeros(output_shape.bytes(), 0);
argument x_zero_point{{x.get_shape().type(), output_shape.lens()}, zeros.data()};
if(args.size() == 3)
{
x_zero_point = args.at(2);
}
argument result{output_shape};
visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
visit_all(result, x_scale)([&](auto output, auto scales) {
par_for(output_shape.elements(), [&](auto i) {
output[i] = static_cast<double>(static_cast<int64_t>(input[i]) -
static_cast<int64_t>(zero_pts[i])) *
scales[i];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
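dequantizelinear above applies the standard linear dequantization y = (x - x_zero_point) * x_scale, widening to int64 before subtracting so that int8 inputs cannot overflow. The per-element formula in isolation:

#include <cstdint>
#include <iostream>

// y = (x - zero_point) * scale, with the subtraction done in a wide
// integer type as in compute() above.
float dequantize(int8_t x, int8_t zero_point, float scale)
{
    return static_cast<float>(static_cast<int64_t>(x) -
                              static_cast<int64_t>(zero_point)) *
           scale;
}

int main()
{
    // scale 0.5, zero point 10: raw 14 -> (14 - 10) * 0.5 = 2
    std::cout << dequantize(14, 10, 0.5f) << '\n'; // prints 2
}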
@@ -18,19 +18,10 @@ namespace op {
struct dot
{
float alpha = 1.0;
float beta = 1.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_type();
check_shapes{inputs, *this}.same_type().has(2);
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
@@ -58,25 +49,14 @@ struct dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
return {t, out_lens};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result;
if(args.size() == 3)
result = args[2];
else
result = argument{output_shape};
argument result = argument{output_shape};
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, alpha, beta); });
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result;
}
};
......
@@ -10,6 +10,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
@@ -38,7 +39,7 @@ struct flatten
std::string name() const { return "flatten"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this}.has(1).standard();
auto&& lens = inputs.front().lens();
auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
@@ -50,7 +51,7 @@
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#define MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#include "migraphx/errors.hpp"
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct get_tuple_elem
{
std::size_t index = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.index, "index"));
}
std::string name() const { return "get_tuple_elem"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).tuple_type();
const auto& sub_shapes = inputs.at(0).sub_shapes();
if(index >= sub_shapes.size())
{
MIGRAPHX_THROW("GET_TUPLE_ELEM: index " + std::to_string(index) + " is out of range " +
std::to_string(sub_shapes.size()));
}
return sub_shapes.at(index);
}
argument compute(const shape&, std::vector<argument> args) const
{
assert(args.size() == 1);
auto vec_args = args.at(0).get_sub_objects();
assert(index < vec_args.size());
return vec_args.at(index);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -35,14 +35,14 @@ struct if_op
MIGRAPHX_THROW("IF: output shapes of submodules must be the same.");
}
return out_shapes0.front();
return shape(out_shapes0);
}
argument compute(
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)>& run) const
argument compute(const shape&,
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
auto cond = args.front().at<bool>();
module_ref mod = cond ? mods[0] : mods[1];
@@ -63,7 +63,7 @@
[](auto&& name, auto&& arg) { return std::make_pair(name, arg); });
auto results = run(mod, params);
return results[0];
return argument{results};
}
};
......
@@ -31,7 +31,9 @@ struct im2col
std::string name() const { return "im2col"; }
shape compute_shape(std::vector<shape> inputs) const
value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
auto input = inputs[0];
auto weights = inputs[1];
@@ -42,17 +44,24 @@
check_shapes{inputs, *this}.has(2);
if(batch_size != 1)
MIGRAPHX_THROW("im2col only support batch_size 1");
auto padding_h = 2 * padding[0];
auto padding_w = 2 * padding[1];
if(padding.size() == 2 * stride.size())
{
padding_h = padding[0] + padding[2];
padding_w = padding[1] + padding[3];
}
auto output_height = std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
stride[0] +
(input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + padding_h) / stride[0] +
1));
auto output_width = std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) /
stride[1] +
(input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + padding_w) / stride[1] +
1));
auto channels_col = kernel_height * kernel_width * input_channels;
return {input.type(), {output_height * output_width, channels_col}};
}
};
......
@@ -6,6 +6,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
@@ -34,9 +35,9 @@ struct load
{
if((offset + s.bytes()) > args[0].get_shape().bytes())
MIGRAPHX_THROW("Load access is out of bounds");
return argument::load(s, args[0].data() + offset);
return argument{s, args[0].data() + offset};
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op)
......
#ifndef MIGRAPHX_GUARD_OPERATORS_LOOP_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOOP_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/module.hpp>
#include <migraphx/run_loop.hpp>
#include <migraphx/ranges.hpp>
#include <cmath>
#include <string>
#include <utility>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct loop
{
int64_t max_iterations = 10;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.max_iterations, "max_iterations"));
}
std::string name() const { return "loop"; }
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
check_shapes{inputs, *this}.standard();
if(mods.size() != 1)
{
MIGRAPHX_THROW("LOOP: operator should have one submodule.");
}
const auto& mod = mods.front();
auto mod_out_shapes = mod->get_output_shapes();
auto dep_param_num = inputs.size() - 2;
// the first item of the module output shapes is the condition used in the loop,
// which is not needed to compute the output shape
mod_out_shapes.erase(mod_out_shapes.begin());
std::vector<shape> ins_out_shapes(mod_out_shapes.begin(),
mod_out_shapes.begin() + dep_param_num);
mod_out_shapes.erase(mod_out_shapes.begin(), mod_out_shapes.begin() + dep_param_num);
for(const auto& out_s : mod_out_shapes)
{
auto lens = out_s.lens();
lens.insert(lens.begin(), max_iterations);
ins_out_shapes.push_back({out_s.type(), lens});
}
return shape(ins_out_shapes);
}
struct ref_loop
{
int64_t max_iterations = 0;
template <class T>
void copy(context&, const argument& src, T& dst) const
{
dst = *src.cast<T>();
}
template <class T>
void copy(context&, T src, const argument& dst) const
{
*dst.cast<T>() = src;
}
void append(const std::vector<argument>& iter_state,
const std::vector<argument>& concatenated_outputs,
int iter) const
{
assert(iter_state.size() == concatenated_outputs.size());
for(auto i : range(iter_state.size()))
{
const auto& iter_stat = iter_state.at(i);
const auto& scan_out = concatenated_outputs.at(i);
auto* in_data = iter_stat.data();
auto* out_data = scan_out.data();
std::size_t out_size = iter_stat.get_shape().bytes();
assert((iter + 1) * out_size <= scan_out.get_shape().bytes());
std::copy(in_data, in_data + out_size, out_data + iter * out_size);
}
}
void set_zero(context&, const std::vector<argument>& concatenated_outputs, int iter) const
{
if(iter >= max_iterations)
return;
for(const auto& out : concatenated_outputs)
{
auto s = out.get_shape();
auto size = s.bytes() / max_iterations;
std::fill(out.data() + iter * size, out.data() + max_iterations * size, 0);
}
}
std::unordered_map<std::string, int> get_output_params(const module&) const { return {}; }
};
argument compute(context& ctx,
const shape& out_shape,
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
// wrap the arguments vector so the ref and gpu implementations are the same
auto cpy_args = args;
bool in_cond = args.at(1).at<bool>();
bool cond = in_cond;
int64_t iter = 0;
// insert iter and cond used in the loop
auto s_cond = args.at(1).get_shape();
auto s_iter = args.at(0).get_shape();
cpy_args.push_back({s_iter, &iter});
cpy_args.push_back({s_cond, &cond});
cpy_args.insert(cpy_args.end(), args.begin() + 2, args.end());
// add cond and mod outputs to the argument list
cpy_args.push_back(argument(s_cond));
cpy_args.push_back(argument(out_shape));
// run loop
return run_loop(ref_loop{max_iterations}, ctx, cpy_args, mods, run);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
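compute_shape above returns a tuple: the body's outputs minus the leading cond, where the first inputs.size() - 2 entries (the loop-carried state) keep their shapes and every remaining scan output gains a leading max_iterations axis. With assumed shapes (max_iterations = 10, one state of float[2], one scan output of float[3]) the result is the tuple {float[2], float[10, 3]}; the lens arithmetic for the scan output:

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    int64_t max_iterations = 10;
    std::vector<std::size_t> scan_lens = {3}; // scan output shape from the body
    // prepend the max_iterations axis, as compute_shape does
    scan_lens.insert(scan_lens.begin(), max_iterations);
    for(auto d : scan_lens)
        std::cout << d << ' '; // prints: 10 3
    std::cout << '\n';
}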
@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
@@ -22,7 +23,7 @@ struct multibroadcast
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_lens, "output_lens"));
return pack(f(self.output_lens, "out_lens"));
}
std::string name() const { return "multibroadcast"; }
@@ -68,7 +69,7 @@ struct multibroadcast
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_MULTINOMIAL_HPP
#define MIGRAPHX_GUARD_OPERATORS_MULTINOMIAL_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_for.hpp>
#include <random>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct multinomial
{
shape::type_t dtype = shape::type_t::int32_type;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dtype, "dtype"));
}
std::string name() const { return "multinomial"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).only_dims(2);
size_t sample_size = inputs.back().lens().back();
if(not contains({shape::int32_type, shape::int64_type}, dtype))
MIGRAPHX_THROW(
"Multinomial: Invalid output type. Valid types are int32_type and int64_type.");
return {dtype, {inputs.front().lens().front(), sample_size}};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
size_t batch_size = output_shape.lens().front();
size_t class_size = args[0].get_shape().lens().back();
size_t sample_size = output_shape.lens().back();
visit_all(args[0], args[1])([&](auto cdf, auto dist) {
result.visit([&](auto output) {
par_for(batch_size * sample_size, [&](auto i) {
auto idx = args[1].get_shape().multi(i);
auto cdf_begin = cdf.begin() + (idx[0] * class_size);
auto cdf_end = cdf_begin + class_size;
auto sample_iter =
std::upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
output[i] = std::distance(cdf_begin, sample_iter);
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
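Each sample above inverts the row's cumulative distribution: the uniform draw dist[i] is scaled by the last CDF entry (so the CDF need not be normalized) and std::upper_bound finds the first class whose cumulative weight exceeds it. The core step in isolation, with made-up weights:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <vector>

int main()
{
    // Unnormalized CDF over 4 classes with weights {1, 3, 2, 4}
    std::vector<float> cdf = {1, 4, 6, 10};
    float u = 0.35f; // uniform draw in [0, 1)
    // Scale the draw by the total, then binary-search for the class
    auto it = std::upper_bound(cdf.begin(), cdf.end(), u * cdf.back());
    std::cout << std::distance(cdf.begin(), it) << '\n'; // 3.5 -> class 1
}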
#ifndef MIGRAPHX_GUARD_OPERATORS_NONZERO_HPP
#define MIGRAPHX_GUARD_OPERATORS_NONZERO_HPP
#include <migraphx/shape_for_each.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/par_for.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct nonzero
{
std::string name() const { return "nonzero"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto elem_num = inputs[0].elements();
auto dim_num = inputs[0].lens().size();
std::vector<std::size_t> out_lens = {dim_num, elem_num};
return {shape::int64_type, out_lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
std::vector<std::vector<std::size_t>> vec_idx;
auto s = args.front().get_shape();
args.front().visit([&](auto v) {
shape_for_each(s, [&](auto idx) {
if(not float_equal(v[s.index(idx)], 0))
{
vec_idx.push_back(idx);
}
});
});
argument result{output_shape};
result.visit([&](auto output) {
std::fill(output.begin(), output.end(), 0);
par_for(vec_idx.size(), [&](auto i) {
for(std::size_t j = 0; j < vec_idx.front().size(); ++j)
{
output[output_shape.index({j, i})] = vec_idx[i][j];
}
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
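Because the output shape must be static, nonzero above allocates the worst case {ndim, elements}, writes one column of coordinates per non-zero element, and leaves the remaining columns zero. For a 2x2 input {{0, 5}, {7, 0}} the non-zero coordinates are (0, 1) and (1, 0), giving the int64 output {{0, 1, 0, 0}, {1, 0, 0, 0}}; a standalone version of that layout:

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // 2x2 input in row-major order: {{0, 5}, {7, 0}}
    std::vector<int> in = {0, 5, 7, 0};
    std::size_t rows = 2, cols = 2;
    // Worst-case output: 2 coordinate rows x 4 columns, zero-filled
    std::vector<long> out(2 * 4, 0);
    std::size_t n = 0;
    for(std::size_t r = 0; r < rows; ++r)
        for(std::size_t c = 0; c < cols; ++c)
            if(in[r * cols + c] != 0)
            {
                out[0 * 4 + n] = r; // row coordinates
                out[1 * 4 + n] = c; // column coordinates
                ++n;
            }
    for(std::size_t d = 0; d < 2; ++d)
    {
        for(std::size_t e = 0; e < 4; ++e)
            std::cout << out[d * 4 + e] << ' ';
        std::cout << '\n'; // prints: 0 1 0 0 / 1 0 0 0
    }
}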