Commit 20b1d690 authored by Paul

Merge branch 'develop' into tests

parents 17aaaa1e ba729cfc
#ifndef MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/requires.hpp>
#include <type_traits>
#include <array>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
template <class R, class...>
struct array_type
{
using type = R;
};
template <class... Ts>
struct array_type<void, Ts...> : std::common_type<Ts...>
{
};
template <class R, class... Ts>
using array_type_t = typename array_type<R, Ts...>::type;
template <class T, std::size_t N, std::size_t... I>
constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&a)[N], seq<I...>)
{
return {{a[I]...}};
}
} // namespace detail
template <class Result = void, class... Ts, MIGRAPHX_REQUIRES((sizeof...(Ts) > 0))>
constexpr std::array<detail::array_type_t<Result, Ts...>, sizeof...(Ts)> make_array(Ts&&... xs)
{
return {static_cast<detail::array_type_t<Result, Ts...>>(std::forward<Ts>(xs))...};
}
constexpr std::array<int, 0> make_array() { return {}; }
template <class T, std::size_t N>
constexpr auto to_array(T (&a)[N])
{
return detail::to_array_impl(a, detail::gens<N>{});
}
namespace detail {
template <std::size_t Offset = 0, class Array, std::size_t... I>
constexpr auto rearray_impl(Array a, seq<I...>)
{
return make_array(a[I + Offset]...);
}
} // namespace detail
template <class T, std::size_t N>
constexpr auto pop_front(std::array<T, N> a)
{
return detail::rearray_impl<1>(a, detail::gens<N - 1>{});
}
template <class T, std::size_t N>
constexpr auto pop_back(std::array<T, N> a)
{
return detail::rearray_impl(a, detail::gens<N - 1>{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
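A brief usage sketch of these helpers (illustrative only; all values are compile-time constants under C++14):
constexpr auto a = make_array(1, 2, 3); // std::array<int, 3>{{1, 2, 3}}
constexpr int raw[] = {4, 5};
constexpr auto b = to_array(raw);       // std::array<int, 2>{{4, 5}}
constexpr auto c = pop_front(a);        // {{2, 3}}
constexpr auto d = pop_back(a);         // {{1, 2}}
static_assert(b[0] == 4, "");
static_assert(c[0] == 2 and d[1] == 2, "");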
@@ -13,9 +13,9 @@ struct program;
 /**
  * Remove identical instructions.
  */
-struct common_subexpression_elimination
+struct eliminate_common_subexpression
 {
-    std::string name() const { return "common_subexpression_elimination"; }
+    std::string name() const { return "eliminate_common_subexpression"; }
     void apply(program& p) const;
 };
@@ -15,6 +15,12 @@ struct swallow
     }
 };
+template <class T>
+auto tuple_size(const T&)
+{
+    return typename std::tuple_size<T>::type{};
+}
 namespace detail {
 template <class R, class F>
@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f)
     return detail::sequence_c_impl(f, detail::gens<N>{});
 }
+template <class IntegerConstant, class F>
+constexpr auto sequence(IntegerConstant ic, F&& f)
+{
+    return sequence_c<ic>(f);
+}
 template <class F, class... Ts>
 constexpr void each_args(F f, Ts&&... xs)
 {
@@ -95,9 +107,9 @@ constexpr void each_args(F)
 }
 template <class F, class T>
-auto unpack(F f, T& x)
+auto unpack(F f, T&& x)
 {
-    return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); });
+    return sequence(tuple_size(x), [&](auto... is) { f(std::get<is>(static_cast<T&&>(x))...); });
 }
 /// Implements a fix-point combinator
@@ -149,6 +161,52 @@ auto index_of(T& x)
     return [&](auto&& y) { return x[y]; };
 }
+template <class T, class... Ts>
+decltype(auto) front_args(T&& x, Ts&&...)
+{
+    return static_cast<T&&>(x);
+}
+template <class... Ts>
+decltype(auto) back_args(Ts&&... xs)
+{
+    return std::get<sizeof...(Ts) - 1>(std::tuple<Ts&&...>(static_cast<Ts&&>(xs)...));
+}
+template <class T, class... Ts>
+auto pop_front_args(T&&, Ts&&... xs)
+{
+    return [&](auto f) { f(static_cast<Ts&&>(xs)...); };
+}
+template <class... Ts>
+auto pop_back_args(Ts&&... xs)
+{
+    return [&](auto f) {
+        using tuple_type = std::tuple<Ts&&...>;
+        auto t = tuple_type(static_cast<Ts&&>(xs)...);
+        return sequence_c<sizeof...(Ts) - 1>(
+            [&](auto... is) { return f(std::get<is>(static_cast<tuple_type&&>(t))...); });
+    };
+}
+template <class T>
+struct always_f
+{
+    T x;
+    template <class... Ts>
+    constexpr T operator()(Ts&&...) const
+    {
+        return x;
+    }
+};
+template <class T>
+auto always(T x)
+{
+    return always_f<T>{x};
+}
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
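A short sketch of the new pack helpers (illustrative only):
auto first = front_args(1, 2, 3); // 1
auto last  = back_args(1, 2, 3);  // 3
pop_front_args(1, 2, 3)([](auto x, auto y) {
    // called with x == 2, y == 3
});
auto r = pop_back_args(1, 2, 3)([](auto x, auto y) {
    return x + y; // called with 1, 2
});
auto two = always(2); // two(anything...) == 2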
@@ -25,7 +25,7 @@ constexpr T normalize(unsigned long z)
 template <class T, MIGRAPHX_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
 constexpr T normalize(unsigned long z)
 {
-    const auto max = std::numeric_limits<T>::max();
+    const auto max = std::numeric_limits<T>::max() / 64;
     const auto half_max = max / 2;
     return half_max - (z % max);
 }
@@ -33,7 +33,7 @@ constexpr T normalize(unsigned long z)
 template <class T, MIGRAPHX_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})>
 constexpr T normalize(unsigned long z)
 {
-    const auto max = std::numeric_limits<T>::max();
+    const auto max = std::numeric_limits<T>::max() / 64;
     return z % max;
 }
@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
     return result;
 }
+template <class T>
+std::vector<T> fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
+{
+    std::vector<T> result(s.elements());
+    std::generate(result.begin(), result.end(), [=] { return value; });
+    return result;
+}
+argument fill_argument(shape s, unsigned long value = 0);
 argument generate_argument(shape s, unsigned long seed = 0);
 literal generate_literal(shape s, unsigned long seed = 0);
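For example (illustrative; assumes a valid shape `s`):
auto ones = fill_tensor_data<float>(s, 1); // s.elements() copies of 1.0f
auto zero = fill_argument(s);              // zero-filled argument of shape s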
@@ -79,6 +79,7 @@ struct literal : raw_data<literal>
     template <class Iterator>
     void fill(Iterator start, Iterator end)
     {
+        assert(std::distance(start, end) == m_shape.elements());
         if(m_shape.standard())
         {
             m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
@@ -8,6 +8,7 @@
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/config.hpp>
 #include <unordered_map>
+#include <unordered_set>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -20,6 +21,12 @@ struct matcher_context
     std::unordered_map<std::string, instruction_ref> instructions;
     instruction_ref not_found() const { return last; }
+    template <class M>
+    bool matched(M m, instruction_ref ins)
+    {
+        return m.match(*this, ins) != this->not_found();
+    }
     private:
     instruction_ref last;
 };
@@ -67,7 +74,7 @@ auto bind_match(M m, std::string name)
     [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
         auto result = m.match(ctx, ins);
         if(result != ctx.not_found())
-            ctx.instructions.emplace(name, ins);
+            ctx.instructions[name] = ins;
         return result;
     });
 }
@@ -205,12 +212,10 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
     return result;
 }
-/// Find matches in a program
+/// Find matches for an instruction in the program
 template <class... Ms>
-void find_matches(program& p, Ms&&... ms)
+void find_matches(program& p, instruction_ref ins, Ms&&... ms)
 {
-    for(auto ins : iterator_for(p))
-    {
     bool match = false;
     each_args(
         [&](auto&& m) {
@@ -223,64 +228,160 @@ void find_matches(program& p, instruction_ref ins, Ms&&... ms)
             match = true;
         },
         ms...);
-    }
+}
+/// Find matches in a program
+template <class... Ms>
+void find_matches(program& p, Ms&&... ms)
+{
+    for(auto ins : iterator_for(p))
+    {
+        find_matches(p, ins, ms...);
+    }
 }
-template <class... Ts>
-auto all_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x and y.match(ctx, ins) != ctx.not_found();
-        })(true, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
-template <class... Ts>
-auto none_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x and y.match(ctx, ins) == ctx.not_found();
-        })(true, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
-template <class... Ts>
-auto any_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x or y.match(ctx, ins) != ctx.not_found();
-        })(false, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
+template <class M>
+struct find_skip
+{
+    M m;
+    M matcher() const { return m; }
+    void apply(program&, const matcher_result&) const {}
+};
+template <class M>
+find_skip<M> make_find_skip(M m)
+{
+    return {m};
+}
+struct lazy_and
+{
+    template <class F, class G>
+    bool operator()(F f, G g) const
+    {
+        return f() and g();
+    }
+};
+struct lazy_or
+{
+    template <class F, class G>
+    bool operator()(F f, G g) const
+    {
+        return f() or g();
+    }
+};
+template <class Op, bool Start, bool Matches>
+struct match_fold_f
+{
+    template <class... Ms>
+    static bool fold_matchers(matcher_context& ctx, instruction_ref ins, Ms... ms)
+    {
+        Op op;
+        auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
+        return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...);
+    }
+    template <class Pack>
+    static bool fold_matchers_pack(matcher_context& ctx, instruction_ref ins, Pack p)
+    {
+        return p([&](auto... ms) { return match_fold_f::fold_matchers(ctx, ins, ms...); });
+    }
+    template <class... Ts>
+    auto operator()(Ts... ms) const
+    {
+        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
+            bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
+            if(matches == Matches)
+                return ins;
+            return ctx.not_found();
+        });
+    }
+    template <class Selector>
+    auto operator[](Selector select) const
+    {
+        return [=](auto... ms) {
+            // Workaround ICE on gcc by packing matchers into an object
+            auto mpack = pack(ms...);
+            return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
+                Op op;
+                bool matches = Start;
+                select(start, [&](auto ins) {
+                    auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
+                    matches = op(always(matches), fm);
+                });
+                if(matches == Matches)
+                    return start;
+                return ctx.not_found();
+            });
+        };
+    }
+};
+const constexpr auto all_of  = match_fold_f<lazy_and, true, true>{};
+const constexpr auto any_of  = match_fold_f<lazy_or, false, true>{};
+const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
+template <class... Ms>
+auto skip_matches(Ms... ms)
+{
+    return make_find_skip(any_of(ms...));
+}
+inline auto inputs()
+{
+    return [](auto ins, auto f) {
+        for(auto&& x : ins->inputs())
+            f(x);
+    };
+}
+inline auto outputs()
+{
+    return [](auto ins, auto f) {
+        for(auto&& x : ins->outputs())
+            f(x);
+    };
+}
 MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
 MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
 MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
+MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
+{
+    return not ins->get_shape().standard();
+}
 MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
 {
     return ins->get_shape().broadcasted();
 }
+MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
+{
+    return ins->get_shape().transposed();
+}
+MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
+{
+    if(ins->inputs().empty())
+        return false;
+    auto s = ins->inputs().front()->get_shape();
+    return std::all_of(
+        ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
+}
-MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
+MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins)
 {
     if(ins->outputs().size() == 1)
         return ins->outputs().front();
     return ctx.not_found();
 }
-MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
+MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
 {
     if(ins->outputs().size() == 1)
         return ins;
@@ -289,10 +390,89 @@ MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
     return ctx.not_found();
 }
+inline auto used_once_recursive(std::size_t depth)
+{
+    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
+        // Used once
+        if(start->outputs().size() == 1)
+            return start;
+        // Unused
+        if(start->outputs().empty())
+        {
+            if(std::next(start) == ctx.not_found())
+                return start;
+            else
+                return ctx.not_found();
+        }
+        // Check for dead instructions
+        auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
+            if(n == 0)
+                return false;
+            if(ins->get_shape().elements() == 0)
+                return false;
+            if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
+                return true;
+            return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
+                return self(i, n - 1);
+            });
+        });
+        auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
+            return is_dead(i, depth);
+        });
+        if(dead + 1 == start->outputs().size())
+            return start;
+        return ctx.not_found();
+    });
+}
+MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
+MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
+{
+    if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
+        return ins;
+    return ctx.not_found();
+}
+template <class... Ms>
+auto skip_output(Ms... ms)
+{
+    auto m = any_of(ms...);
+    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
+        return fix<instruction_ref>([&](auto self, auto ins) {
+            if(ins->outputs().size() == 1)
+            {
+                auto next = ins->outputs().front();
+                if(ctx.matched(m, next))
+                {
+                    auto skipped_next = self(next);
+                    if(skipped_next != ctx.not_found())
+                        return skipped_next;
+                }
+                return next;
+            }
+            return ctx.not_found();
+        })(start);
+    });
+}
-inline auto name(std::string name)
+inline auto name(std::string s)
 {
     return make_basic_pred_matcher(
-        [ =, name = std::move(name) ](instruction_ref ins) { return ins->name() == name; });
+        [ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
 }
+inline auto name(std::unordered_set<std::string> names)
+{
+    return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
+        return names.count(ins->name()) > 0;
+    });
+}
+template <class... Ts>
+inline auto name(std::string s, Ts... xs) // NOLINT
+{
+    return name(std::unordered_set<std::string>{std::move(s), std::move(xs)...});
+}
 inline auto nargs(std::size_t n)
@@ -302,7 +482,7 @@ inline auto nargs(std::size_t n)
 inline auto arg(std::size_t i)
 {
-    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
+    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) {
         if(i < ins->inputs().size())
             return ins->inputs()[i];
         return ctx.not_found();
@@ -338,6 +518,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
     };
 }
+template <class M>
+auto same_shape(M m)
+{
+    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
+        auto i = m.match(ctx, ins);
+        if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
+            return ins;
+        return ctx.not_found();
+    });
+}
+template <class... Ms>
+auto same_shape(Ms... ms)
+{
+    return all_of(same_shape(ms)...);
+}
 } // namespace match
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
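A minimal usage sketch of the resulting matcher API (illustrative only: `find_pointwise_add` is a hypothetical finder object, and `r.result` follows the matcher_result convention used by match_instruction above):
struct find_pointwise_add
{
    // Matches an "add" instruction whose inputs all share one shape
    auto matcher() const { return match::name("add")(match::same_input_shapes()); }
    void apply(program& p, const match::matcher_result& r) const
    {
        // r.result is the matched instruction; rewrite p here
    }
};
// Run all finders over every instruction in the program:
//   match::find_matches(p, find_pointwise_add{});
// The selector form folds matchers over related instructions, e.g.
//   match::all_of[match::inputs()](match::standard_shape())
// matches only when every input of an instruction has a standard shape.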
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct argmax
{
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "argmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMAX: axis is out of range.");
}
lens[axis] = 1;
return {shape::int64_type, lens};
}
template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end());
if(max_val < cur_val)
{
max_val = cur_val;
max_index = i;
}
}
return max_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num);
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
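A shape-level sketch (illustrative only): reducing axis 1 of a {2, 3, 4} input yields an int64 index tensor of shape {2, 1, 4}:
op::argmax am{1};
shape s{shape::float_type, {2, 3, 4}};
auto out = am.compute_shape({s}); // {shape::int64_type, {2, 1, 4}}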
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct argmin
{
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "argmin"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMIN: axis is out of range.");
}
lens[axis] = 1;
return {shape::int64_type, lens};
}
template <class T>
int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end());
if(min_val > cur_val)
{
min_val = cur_val;
min_index = i;
}
}
return min_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
std::size_t batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmin(input, data_idx, batch_item_num);
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -28,23 +28,31 @@ struct binary : op_name<Derived>
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(input1.get_shape().packed() and input2.get_shape().packed())
-            {
+        auto s1 = args[0].get_shape();
+        auto s2 = args[1].get_shape();
+        if(s1 == s2 and s1.packed())
+        {
+            shape std_shape{s1.type(), s1.lens()};
+            argument std_result{std_shape, result.data()};
+            argument std_arg0{std_shape, args[0].data()};
+            argument std_arg1{std_shape, args[1].data()};
+            visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
                 std::transform(input1.begin(),
                                input1.end(),
                                input2.begin(),
                                output.begin(),
                                static_cast<const Derived&>(*this).apply());
-            }
-            else
-            {
+            });
+        }
+        else
+        {
+            visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
                 shape_for_each(output.get_shape(), [&](const auto& idx) {
                     output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
                         input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
                 });
-            }
-        });
+            });
+        }
         return result;
     }
#ifndef MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <functional>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct capture
{
std::size_t ins_index;
std::function<void(std::size_t ins_index, std::vector<argument>)> f{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.ins_index, "ins_index"));
}
std::string name() const { return "capture"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(const shape&, std::vector<argument> args) const
{
if(f)
{
f(ins_index, args);
}
else
{
MIGRAPHX_THROW("CAPTURE: callback function is not callable!");
}
return args.front();
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
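Wiring a callback looks like this (the lambda is hypothetical; the operator forwards `ins_index` and the runtime arguments, then returns the first argument unchanged):
op::capture cap{5, [](std::size_t idx, std::vector<argument> args) {
    // e.g. record the values seen at instruction `idx` for calibration
}};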
@@ -44,8 +44,7 @@ struct convolution
         const shape& input   = inputs.at(0);
         const shape& weights = inputs.at(1);
         auto t               = input.type();
-        if(padding_mode == default_)
-        {
         return {t,
                 {
                     input.lens()[0],
@@ -64,32 +63,6 @@ struct convolution
                         1)),
                 }};
-        }
-        else if(padding_mode == same)
-        {
-            return {t,
-                    {input.lens()[0],
-                     weights.lens()[0],
-                     static_cast<std::size_t>(
-                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
-                     static_cast<std::size_t>(
-                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
-        }
-        else if(padding_mode == valid)
-        {
-            return {
-                t,
-                {input.lens()[0],
-                 weights.lens()[0],
-                 static_cast<std::size_t>(std::ceil(
-                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
-                 static_cast<std::size_t>(std::ceil(
-                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
-        }
-        else
-        {
-            MIGRAPHX_THROW("Invalid padding mode");
-        }
     }
 };
 } // namespace op
#ifndef MIGRAPHX_GUARD_OPERATORS_ERF_HPP
#define MIGRAPHX_GUARD_OPERATORS_ERF_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct erf : unary<erf>
{
auto apply() const
{
return [](auto x) { return std::erf(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
 #ifndef MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
 #define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
-#include <array>
 #include <migraphx/operation.hpp>
 #include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
-#include <cmath>
-#include <utility>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -30,7 +23,7 @@ struct logsoftmax
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(1).standard();
-        if(axis < 0 || axis > inputs[0].lens().size())
+        if(axis < 0 || axis >= inputs[0].lens().size())
         {
             MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
                            " is out of range");
@@ -35,14 +35,28 @@ struct multibroadcast
         auto input = inputs.at(0);
         if(input.lens().empty())
-            MIGRAPHX_THROW("inputs dimensions should be > 0");
+        {
+            MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
+        }
         if(input.lens().size() > output_lens.size())
-            MIGRAPHX_THROW("inputs dimensions should <= output size");
-        std::vector<size_t> bcast_strides(output_lens.size(), 0);
+        {
+            MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
+        }
         auto offset = output_lens.size() - input.lens().size();
         for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
+        {
+            if(output_lens[i + offset] != input.lens()[i] and input.lens()[i] != 1)
+            {
+                MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(input.lens()) +
+                               "} cannot be broadcasted to {" + to_string_range(output_lens) +
+                               "}!");
+            }
+        }
+        std::vector<size_t> bcast_strides(output_lens.size(), 0);
+        for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
         {
             if(output_lens[i + offset] == input.lens()[i])
             {
@@ -31,7 +31,7 @@ struct pooling
     {
         return pack(f(self.mode, "mode"),
                     f(self.padding, "padding"),
-                    f(self.padding, "padding_mode"),
+                    f(self.padding_mode, "padding_mode"),
                     f(self.stride, "stride"),
                     f(self.lengths, "lengths"));
     }
@@ -48,52 +48,22 @@ struct pooling
         assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
         assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
-        if(padding_mode == default_)
-        {
         return {t,
                 {
                     input.lens()[0],
                     input.lens()[1],
                     std::size_t(std::max<std::ptrdiff_t>(
                         1,
-                        floor_divide<std::ptrdiff_t>(
-                            input.lens()[2] + 2 * padding[0] - lengths[0], stride[0]) +
+                        floor_divide<std::ptrdiff_t>(input.lens()[2] + 2 * padding[0] - lengths[0],
+                                                     stride[0]) +
                             1)),
                     std::size_t(std::max<std::ptrdiff_t>(
                         1,
-                        floor_divide<std::ptrdiff_t>(
-                            input.lens()[3] + 2 * padding[1] - lengths[1], stride[1]) +
+                        floor_divide<std::ptrdiff_t>(input.lens()[3] + 2 * padding[1] - lengths[1],
                                                      stride[1]) +
                             1)),
                 }};
-        }
-        else if(padding_mode == same)
-        {
-            return {t,
-                    {input.lens()[0],
-                     input.lens()[1],
-                     ceil_divide<std::size_t>(input.lens()[2], stride[0]),
-                     ceil_divide<std::size_t>(input.lens()[3], stride[1])}};
-        }
-        else if(padding_mode == valid)
-        {
-            return {
-                t,
-                {
-                    input.lens()[0],
-                    input.lens()[1],
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) + 1)),
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) + 1)),
-                }};
-        }
-        else
-        {
-            MIGRAPHX_THROW("Invalid padding mode");
-        }
     }
 };
 } // namespace op
#ifndef MIGRAPHX_GUARD_OPERATORS_POW_HPP
#define MIGRAPHX_GUARD_OPERATORS_POW_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/config.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct pow : binary<pow>
{
auto apply() const
{
return [](auto x, auto y) { return std::pow(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct quant_convolution
{
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}};
padding_mode_t padding_mode = default_;
int group = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode"),
f(self.group, "group"));
}
std::string name() const { return "quant_convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
auto t = input.type();
// all inputs must be int8_type and the output is int32_type
if(t != shape::int8_type)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t");
}
t = shape::int32_type;
return {t,
{
input.lens()[0],
weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
2 * padding[0]) /
stride[0] +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
2 * padding[1]) /
stride[1] +
1)),
}};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
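A shape-level sketch (illustrative only; default padding, stride, and dilation):
op::quant_convolution qconv{};
shape x{shape::int8_type, {1, 3, 32, 32}};
shape w{shape::int8_type, {8, 3, 5, 5}};
auto out = qconv.compute_shape({x, w}); // {shape::int32_type, {1, 8, 28, 28}}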
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct quant_dot
{
int32_t alpha = 1;
int32_t beta = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(as_number(self.alpha), "alpha"), f(as_number(self.beta), "beta"));
}
std::string name() const { return "quant_dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type();
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if(t != shape::int8_type)
{
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
}
if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{
MIGRAPHX_THROW("QUANT_DOT: dot only accept 2 or more dims operands");
}
// only handle the case that the batch size of a and b are the same
if(!std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{
MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" +
to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
}
std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
{
MIGRAPHX_THROW("QUANT_DOT: inner dimensions do not match: {" +
to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
}
// k must be a multiple of 4
if((a.lens()[dim_1] % 4) != 0)
{
MIGRAPHX_THROW("QUANT_DOT: size of A {" + to_string_range(a.lens()) + "} and B {" +
to_string_range(b.lens()) + "} must be multiple of 4 for int8 type");
}
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("QUANT_DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
if(inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
{
MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
}
return {shape::int32_type, out_lens};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_REDUCE_MEAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_REDUCE_MEAN_HPP
#include <migraphx/op/reduce_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reduce_mean : reduce_op<reduce_mean>
{
reduce_mean() {}
reduce_mean(std::vector<int64_t> ax) : reduce_op(std::move(ax)) {}
auto op() const
{
return [=](auto x, auto y) { return x + y; };
}
auto output(const shape& s) const
{
return [&](auto val) { return val / s.elements(); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
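A shape-level sketch (illustrative only); output() divides the accumulated sum by the number of reduced elements:
op::reduce_mean rm{{1}};
shape s{shape::float_type, {2, 3, 4}};
auto out = rm.compute_shape({s}); // {shape::float_type, {2, 1, 4}}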
#ifndef MIGRAPHX_GUARD_OPERATORS_REDUCE_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_REDUCE_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <limits>
#include <numeric>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct lowest
{
template <class T>
operator T() const
{
return std::numeric_limits<T>::lowest();
}
};
struct highest
{
template <class T>
operator T() const
{
return std::numeric_limits<T>::max();
}
};
struct zero
{
template <class T>
operator T() const
{
return T{0};
}
};
template <class Derived>
struct reduce_op : op_name<Derived>
{
std::vector<std::int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::vector<int64_t> tune_axes(std::size_t n_dim) const
{
auto tuned_axes = axes;
if(tuned_axes.empty())
{
tuned_axes.resize(n_dim);
std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
}
else
{
for(auto& axis : tuned_axes)
{
int64_t s_dim = static_cast<int64_t>(n_dim);
if(axis >= s_dim or axis < -s_dim)
{
MIGRAPHX_THROW("REDUCE_OP: axis out of range");
}
if(axis < 0)
{
axis += n_dim;
}
}
}
return tuned_axes;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes)
{
lens[axis] = 1;
}
return {s.type(), lens};
}
template <class T>
void tune_dims(const std::vector<int64_t>& tuned_axes,
const std::vector<T>& in_lens,
std::vector<T>& out_lens) const
{
for(auto axis : tuned_axes)
{
out_lens[axis] = in_lens[axis];
}
}
template <class T>
void reduce(tensor_view<T>& input,
shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = static_cast<const Derived&>(*this).init();
shape_for_each(batch_shape, [&](auto b_idx) {
this->tune_dims(tuned_axes, b_idx, data_idx);
val = static_cast<const Derived&>(*this).op()(
static_cast<const Derived&>(*this).input()(input(data_idx.begin(), data_idx.end())),
val);
});
output(out_idx.begin(), out_idx.end()) =
static_cast<const Derived&>(*this).output(batch_shape)(val);
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
auto tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
tune_dims(tuned_axes, arg_lens, batch_lens);
shape batch_shape{output_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i);
this->reduce(input, batch_shape, tuned_axes, out_idx, output);
});
});
return result;
}
auto init() const { return zero(); }
auto input() const
{
return [&](auto val) { return val; };
}
auto output(const shape&) const
{
return [&](auto val) { return val; };
}
reduce_op() {}
reduce_op(std::vector<int64_t> ax) : axes(std::move(ax)) {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
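Deriving a new reduction from this base only requires supplying the hooks; note that tune_axes also normalizes negative axes (e.g. axes {-1} on a rank-3 input becomes {2}). A hypothetical reduce_sum that mirrors reduce_mean above but keeps the raw sum:
struct reduce_sum : reduce_op<reduce_sum>
{
    reduce_sum() {}
    reduce_sum(std::vector<int64_t> ax) : reduce_op(std::move(ax)) {}
    auto op() const
    {
        return [](auto x, auto y) { return x + y; };
    }
};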