Commit 51597ed7 authored by Khalique

fix tests and tf parser

parents 7bacd3ba bc80dee8
@@ -7,6 +7,14 @@
 #include <migraphx/onnx.hpp>
 #include <migraphx/stringutils.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_identity.hpp>
+#include <migraphx/eliminate_pad.hpp>
+#include <migraphx/propagate_constant.hpp>
+#include <migraphx/simplify_algebra.hpp>
+#include <migraphx/simplify_reshapes.hpp>

 namespace migraphx {
 namespace driver {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -17,6 +25,7 @@ struct loader
     std::string file_type;
     bool is_nhwc  = true;
     unsigned trim = 0;
+    bool optimize = false;

     void parse(argument_parser& ap)
     {
@@ -26,6 +35,7 @@ struct loader
         ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
         ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
         ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
+        ap(optimize, {"--optimize"}, ap.help("Optimize when reading"), ap.set_value(true));
     }

     program load()
@@ -48,6 +58,20 @@ struct loader
             auto last = std::prev(p.end(), trim);
             p.remove_instructions(last, p.end());
         }
+        if(optimize)
+            migraphx::run_passes(p,
+                                 {
+                                     migraphx::eliminate_identity{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::simplify_algebra{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::simplify_reshapes{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::propagate_constant{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::eliminate_pad{},
+                                     migraphx::dead_code_elimination{},
+                                 });
         return p;
     }
 };
......
@@ -190,6 +190,23 @@ auto pop_back_args(Ts&&... xs)
     };
 }

+template <class T>
+struct always_f
+{
+    T x;
+    template <class... Ts>
+    constexpr T operator()(Ts&&...) const
+    {
+        return x;
+    }
+};
+
+template <class T>
+auto always(T x)
+{
+    return always_f<T>{x};
+}
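For orientation, always simply wraps a value in a callable that ignores its arguments; a minimal usage sketch (variable names hypothetical):

    // k is a function object; calling it with any arguments returns 42.
    auto k = migraphx::always(42);
    int a = k();               // 42
    int b = k("ignored", 3.5); // 42 -- arguments are discarded

It exists so the fold matchers added below can treat an already-computed bool and a lazily-evaluated matcher uniformly as nullary callables.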
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
......
@@ -8,6 +8,7 @@
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/config.hpp>
 #include <unordered_map>
+#include <unordered_set>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -20,6 +21,12 @@ struct matcher_context
     std::unordered_map<std::string, instruction_ref> instructions;

     instruction_ref not_found() const { return last; }

+    template <class M>
+    bool matched(M m, instruction_ref ins)
+    {
+        return m.match(*this, ins) != this->not_found();
+    }
+
     private:
     instruction_ref last;
 };
@@ -205,12 +212,10 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
     return result;
 }

-/// Find matches in a program
+/// Find matches for an instruction in the program
 template <class... Ms>
-void find_matches(program& p, Ms&&... ms)
+void find_matches(program& p, instruction_ref ins, Ms&&... ms)
 {
-    for(auto ins : iterator_for(p))
-    {
     bool match = false;
     each_args(
         [&](auto&& m) {
@@ -223,56 +228,131 @@ void find_matches(program& p, Ms&&... ms)
             match = true;
         },
         ms...);
-    }
+}
+
+/// Find matches in a program
+template <class... Ms>
+void find_matches(program& p, Ms&&... ms)
+{
+    for(auto ins : iterator_for(p))
+    {
+        find_matches(p, ins, ms...);
+    }
 }
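To see how the two overloads compose, a hedged sketch of a matcher pass in the find_* style this commit uses in simplify_reshapes (struct and pattern names hypothetical):

    struct find_example
    {
        // The pattern to search for
        auto matcher() const { return match::name("contiguous"); }
        // Called with the program and the match result for each hit
        void apply(program& p, const match::matcher_result& mr) const
        {
            instruction_ref ins = mr.result;
            // ... rewrite the program around ins ...
        }
    };

    // Per-instruction: match::find_matches(p, ins, find_example{});
    // Whole program:   match::find_matches(p, find_example{});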
-template <class... Ts>
-auto all_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x and y.match(ctx, ins) != ctx.not_found();
-        })(true, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
-
-template <class... Ts>
-auto none_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x and y.match(ctx, ins) == ctx.not_found();
-        })(true, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
-
-template <class... Ts>
-auto any_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x or y.match(ctx, ins) != ctx.not_found();
-        })(false, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
+struct lazy_and
+{
+    template <class F, class G>
+    bool operator()(F f, G g) const
+    {
+        return f() and g();
+    }
+};
+
+struct lazy_or
+{
+    template <class F, class G>
+    bool operator()(F f, G g) const
+    {
+        return f() or g();
+    }
+};
+
+template <class Op, bool Start, bool Matches>
+struct match_fold_f
+{
+    template <class... Ms>
+    static bool fold_matchers(matcher_context& ctx, instruction_ref ins, Ms... ms)
+    {
+        Op op;
+        auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
+        return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...);
+    }
+
+    template <class Pack>
+    static bool fold_matchers_pack(matcher_context& ctx, instruction_ref ins, Pack p)
+    {
+        return p([&](auto... ms) { return match_fold_f::fold_matchers(ctx, ins, ms...); });
+    }
+
+    template <class... Ts>
+    auto operator()(Ts... ms) const
+    {
+        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
+            bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
+            if(matches == Matches)
+                return ins;
+            return ctx.not_found();
+        });
+    }
+
+    template <class Selector>
+    auto operator[](Selector select) const
+    {
+        return [=](auto... ms) {
+            // Workaround ICE on gcc by packing matchers into an object
+            auto mpack = pack(ms...);
+            return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
+                Op op;
+                bool matches = Start;
+                select(start, [&](auto ins) {
+                    auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
+                    matches = op(always(matches), fm);
+                });
+                if(matches == Matches)
+                    return start;
+                return ctx.not_found();
+            });
+        };
+    }
+};
+
+const constexpr auto all_of  = match_fold_f<lazy_and, true, true>{};
+const constexpr auto any_of  = match_fold_f<lazy_or, false, true>{};
+const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
+
+inline auto inputs()
+{
+    return [](auto ins, auto f) {
+        for(auto&& x : ins->inputs())
+            f(x);
+    };
+}
+
+inline auto outputs()
+{
+    return [](auto ins, auto f) {
+        for(auto&& x : ins->outputs())
+            f(x);
+    };
+}
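A hedged usage sketch of the rewritten fold matchers together with the new selectors (the first two patterns appear verbatim in simplify_reshapes below):

    // A concat whose inputs all have transposed shapes:
    auto m1 = match::name("concat")(
        match::all_of[match::inputs()](match::transpose_shape()));
    // An instruction with at least one reshaper among its outputs:
    auto m2 = match::any_of[match::outputs()](match::name("reshape"));
    // An instruction that is neither standard nor broadcasted:
    auto m3 = match::none_of(match::standard_shape(), match::broadcast_shape());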
 MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
 MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
 MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
+MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
+{
+    return not ins->get_shape().standard();
+}
 MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
 {
     return ins->get_shape().broadcasted();
 }
+MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
+{
+    return ins->get_shape().transposed();
+}
+MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
+{
+    if(ins->inputs().empty())
+        return false;
+    auto s = ins->inputs().front()->get_shape();
+    return std::all_of(
+        ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
+}
 MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
 {
     if(ins->outputs().size() == 1)
@@ -289,10 +369,39 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
     return ctx.not_found();
 }
+template <class... Ms>
+auto skip_output(Ms... ms)
+{
+    auto m = any_of(ms...);
+    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
+        return fix<instruction_ref>([&](auto self, auto ins) {
+            if(ins->outputs().size() == 1)
+            {
+                auto next = ins->outputs().front();
+                if(ctx.matched(m, next))
+                {
+                    auto skipped_next = self(next);
+                    if(skipped_next != ctx.not_found())
+                        return skipped_next;
+                }
+                return next;
+            }
+            return ctx.not_found();
+        })(start);
+    });
+}
+
-inline auto name(std::string name)
+inline auto name(std::string s)
 {
     return make_basic_pred_matcher(
-        [ =, name = std::move(name) ](instruction_ref ins) { return ins->name() == name; });
+        [ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
 }
+
+inline auto name(std::unordered_set<std::string> names)
+{
+    return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
+        return names.count(ins->name()) > 0;
+    });
+}
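skip_output walks single-output chains, passing over instructions matched by ms; its intended use shows up in find_transpose below. A hedged sketch:

    // Matches a transpose that does not feed another transpose,
    // even through an intervening contiguous:
    auto m = match::name("transpose")(match::none_of(
        match::skip_output(match::name("contiguous"))(match::name("transpose"))));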
 inline auto nargs(std::size_t n)
@@ -338,6 +447,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
     };
 }

+template <class M>
+auto same_shape(M m)
+{
+    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
+        auto i = m.match(ctx, ins);
+        if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
+            return ins;
+        return ctx.not_found();
+    });
+}
+
+template <class... Ms>
+auto same_shape(Ms... ms)
+{
+    return all_of(same_shape(ms)...);
+}
 } // namespace match
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct argmax
{
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "argmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMAX: axis is out of range.");
}
lens[axis] = 1;
return {shape::int64_type, lens};
}
template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end());
if(max_val < cur_val)
{
max_val = cur_val;
max_index = i;
}
}
return max_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num);
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
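A hedged sketch of the new operator's behavior (program-construction API as used elsewhere in this commit; values illustrative):

    // Input shape {2, 3}, axis = 1 -> output shape {2, 1} of int64 indices.
    migraphx::program p;
    auto x = p.add_literal(migraphx::literal{
        migraphx::shape{migraphx::shape::float_type, {2, 3}}, {1, 5, 2, 9, 0, 3}});
    p.add_instruction(migraphx::op::argmax{1}, x);
    // Row {1, 5, 2} -> index 1; row {9, 0, 3} -> index 0.

argmin below is identical except that calc_argmin keeps the smallest value's index.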
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct argmin
{
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "argmin"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMIN: axis is out of range.");
}
lens[axis] = 1;
return {shape::int64_type, lens};
}
template <class T>
int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end());
if(min_val > cur_val)
{
min_val = cur_val;
min_index = i;
}
}
return min_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
std::size_t batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmin(input, data_idx, batch_item_num);
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
 #ifndef MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
 #define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP

-#include <array>
 #include <migraphx/operation.hpp>
 #include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
-#include <cmath>
-#include <utility>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
......
@@ -35,14 +35,28 @@ struct multibroadcast
         auto input = inputs.at(0);

         if(input.lens().empty())
-            MIGRAPHX_THROW("inputs dimensions should be > 0");
+        {
+            MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
+        }

         if(input.lens().size() > output_lens.size())
-            MIGRAPHX_THROW("inputs dimensions should <= output size");
+        {
+            MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
+        }

-        std::vector<size_t> bcast_strides(output_lens.size(), 0);
         auto offset = output_lens.size() - input.lens().size();
         for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
+        {
+            if(output_lens[i + offset] != input.lens()[i] and input.lens()[i] != 1)
+            {
+                MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(input.lens()) +
+                               "} cannot be broadcasted to {" + to_string_range(output_lens) +
+                               "}!");
+            }
+        }
+
+        std::vector<size_t> bcast_strides(output_lens.size(), 0);
+        for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
         {
             if(output_lens[i + offset] == input.lens()[i])
             {
......
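To illustrate the validation this adds, a hedged sketch (shapes illustrative; multibroadcast aggregate-initialized with its output lens):

    migraphx::op::multibroadcast b{{2, 3, 4}};
    // ok: trailing dims align from the right, and a 1 may stretch:
    b.compute_shape({migraphx::shape{migraphx::shape::float_type, {3, 1}}}); // -> {2, 3, 4}
    // throws: 2 matches neither 4 nor 1:
    // b.compute_shape({migraphx::shape{migraphx::shape::float_type, {3, 2}}});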
 #ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
 #define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP

-#include <array>
 #include <migraphx/operation.hpp>
 #include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
-#include <cmath>
-#include <utility>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
......
@@ -5,6 +5,8 @@
 #include <migraphx/op/abs.hpp>
 #include <migraphx/op/acos.hpp>
 #include <migraphx/op/add.hpp>
+#include <migraphx/op/argmax.hpp>
+#include <migraphx/op/argmin.hpp>
 #include <migraphx/op/asin.hpp>
 #include <migraphx/op/as_shape.hpp>
 #include <migraphx/op/atan.hpp>
......
@@ -99,6 +99,8 @@ struct shape
     /// Map element index to space index
     std::size_t index(std::size_t i) const;

+    std::vector<std::size_t> multi(std::size_t i) const;
+
     /// Returns true if the shape is packed with no padding
     bool packed() const;
     /// Returns true is the shape has been transposed. That is the strides are not in descending
......
@@ -63,6 +63,8 @@ struct onnx_parser
         add_variadic_op("Max", op::max{});
         add_variadic_op("Min", op::min{});
+        add_mem_op("ArgMax", &onnx_parser::parse_argmax);
+        add_mem_op("ArgMin", &onnx_parser::parse_argmin);
         add_mem_op("Clip", &onnx_parser::parse_clip);
         add_mem_op("LRN", &onnx_parser::parse_lrn);
         add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
@@ -93,6 +95,7 @@ struct onnx_parser
         add_mem_op("GRU", &onnx_parser::parse_gru);
         add_mem_op("LSTM", &onnx_parser::parse_lstm);
         add_mem_op("Pad", &onnx_parser::parse_pad);
+        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);

         // init the activation function map
         init_actv_func();
@@ -182,7 +185,15 @@ struct onnx_parser
                        s0.end(),
                        s1.begin() + offset,
                        out_lens.begin() + offset,
-                       [](auto a, auto b) { return std::max(a, b); });
+                       [&](auto a, auto b) {
+                           if(a != b and a != 1 and b != 1)
+                           {
+                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
+                                              to_string_range(s0) + "} and {" +
+                                              to_string_range(s1) + "} mismatch!");
+                           }
+                           return std::max(a, b);
+                       });
         return out_lens;
     }
@@ -266,6 +277,60 @@ struct onnx_parser
         return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
     }

+    instruction_ref parse_argmax(const std::string&,
+                                 const attribute_map& attributes,
+                                 std::vector<instruction_ref> args)
+    {
+        int64_t axis = 0;
+        if(contains(attributes, "axis"))
+        {
+            axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
+        }
+
+        int keep_dims = 1;
+        if(contains(attributes, "keepdims"))
+        {
+            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
+        }
+
+        if(keep_dims == 0)
+        {
+            auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
+            return prog.add_instruction(op::squeeze{{axis}}, ins);
+        }
+        else
+        {
+            return prog.add_instruction(op::argmax{axis}, std::move(args));
+        }
+    }
+
+    instruction_ref parse_argmin(const std::string&,
+                                 const attribute_map& attributes,
+                                 std::vector<instruction_ref> args)
+    {
+        int64_t axis = 0;
+        if(contains(attributes, "axis"))
+        {
+            axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
+        }
+
+        int keep_dims = 1;
+        if(contains(attributes, "keepdims"))
+        {
+            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
+        }
+
+        if(keep_dims == 0)
+        {
+            auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
+            return prog.add_instruction(op::squeeze{{axis}}, ins);
+        }
+        else
+        {
+            return prog.add_instruction(op::argmin{axis}, std::move(args));
+        }
+    }
     instruction_ref
     parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
@@ -1222,6 +1287,40 @@ struct onnx_parser
         return {hidden_states, last_output, last_cell_output};
     }
+    instruction_ref parse_reduce_sum(const std::string&,
+                                     attribute_map attributes,
+                                     std::vector<instruction_ref> args)
+    {
+        std::size_t n_dim = args.front()->get_shape().lens().size();
+        // default to reduce over all dimensions
+        std::vector<std::size_t> axes(n_dim);
+        std::iota(axes.begin(), axes.end(), 0);
+        if(contains(attributes, "axes"))
+        {
+            axes.clear();
+            auto&& attr_axes = attributes["axes"].ints();
+            axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
+        }
+
+        int keep_dims = 1;
+        if(contains(attributes, "keepdims"))
+        {
+            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
+        }
+
+        if(keep_dims == 1)
+        {
+            return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
+        }
+        else
+        {
+            auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
+            std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
+            return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
+        }
+    }
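A hedged summary of the new ReduceSum mapping (shapes illustrative):

    // ONNX ReduceSum on input shape {2, 3, 4} with axes = [1]:
    //   keepdims = 1 (default) -> op::reduce_sum{{1}}                       -> {2, 1, 4}
    //   keepdims = 0           -> op::reduce_sum{{1}} then op::squeeze{{1}} -> {2, 4}
    // With no axes attribute, every dimension is reduced.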
     void parse_from(std::istream& is)
     {
         onnx::ModelProto model;
......
@@ -138,6 +138,24 @@ std::size_t shape::index(std::size_t i) const
         return result;
     }
 }

+std::vector<std::size_t> shape::multi(std::size_t i) const
+{
+    assert(this->standard());
+    std::vector<std::size_t> indices(lens().size());
+    std::transform(strides().begin(),
+                   strides().end(),
+                   lens().begin(),
+                   indices.begin(),
+                   [&](std::size_t stride, std::size_t len) {
+                       assert(len > 0 and stride > 0);
+                       return (i / stride) % len;
+                   });
+    return indices;
+}
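shape::multi inverts the row-major linearization, which is why it asserts a standard shape; a hedged worked example:

    // A standard shape {2, 3, 4} has strides {12, 4, 1}.
    migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
    auto idx = s.multi(17);
    // idx = {17/12 % 2, 17/4 % 3, 17/1 % 4} = {1, 1, 1};
    // dotting it back with the strides recovers 17: 1*12 + 1*4 + 1*1.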
 bool shape::packed() const { return this->elements() == this->element_space(); }

 bool shape::transposed() const
......
@@ -2,14 +2,17 @@
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/op/as_shape.hpp>
+#include <migraphx/op/transpose.hpp>
+#include <migraphx/op/concat.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/ranges.hpp>
+#include <migraphx/matcher.hpp>
 #include <unordered_set>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-bool is_reshaper(instruction_ref ins)
+const auto& reshaper_names()
 {
     // clang-format off
     static const std::unordered_set<std::string> names = {
@@ -19,17 +22,10 @@ bool is_reshaper(instruction_ref ins)
         "unsqueeze"
     };
     // clang-format on
-    return contains(names, ins->name());
+    return names;
 }

-bool is_transpose_output(instruction_ref ins)
-{
-    if(ins->outputs().size() != 1)
-        return false;
-    if(ins->outputs().front()->name() == "contiguous")
-        return is_transpose_output(ins->outputs().front());
-    return ins->outputs().front()->name() == "transpose";
-}
+bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); }

 instruction_ref find_transpose_input(instruction_ref ins)
 {
@@ -42,21 +38,62 @@ instruction_ref find_transpose_input(instruction_ref ins)
     return ins;
 }

-void simplify_reshapes::apply(program& p) const
+auto get_transpose_dims(instruction_ref ins)
 {
-    auto end = std::prev(p.end());
-    for(auto ins : iterator_for(p))
+    return any_cast<const op::transpose&>(ins->get_operator()).dims;
+}
+
+std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t> permutation)
+{
+    std::vector<int64_t> result(dims.size());
+    assert(dims.size() == permutation.size());
+    for(std::size_t i = 0; i < dims.size(); i++)
     {
-        if(ins == end and ins->name() == "contiguous")
-            continue;
-        // Skip possible dead instructions
-        if(ins->outputs().empty() and ins != end)
-            continue;
-        if(is_reshaper(ins))
+        result[i] = dims[permutation[i]];
+    }
+    return result;
+}
+
+bool is_no_transpose(const std::vector<int64_t>& dims)
+{
+    if(dims.empty())
+        return true;
+    if(dims.front() != 0)
+        return false;
+    return std::adjacent_find(
+               dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end();
+}
+
+template <class Vector, class Op>
+std::vector<int64_t> sort_permutation(const Vector& data, Op op)
+{
+    std::vector<std::int64_t> result(data.size());
+    std::iota(result.begin(), result.end(), 0);
+    std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
+    return result;
+}
+
+std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
+{
+    return sort_permutation(permutation, std::less<>{});
+}
+
+std::vector<int64_t> find_permutation(const shape& s)
+{
+    return sort_permutation(s.strides(), std::greater<>{});
+}
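A hedged worked example of the permutation helpers (values illustrative):

    // transpose{1, 0, 2} on a standard {3, 2, 4} tensor yields lens {2, 3, 4}
    // with strides {4, 8, 1}; sorting strides in descending order recovers it:
    //   find_permutation(s)           == {1, 0, 2}
    //   invert_permutation({1, 0, 2}) == {1, 0, 2}    // a swap is its own inverse
    //   invert_permutation({2, 0, 1}) == {1, 2, 0}    // a 3-cycle is not
    //   reorder_dims({2, 0, 1}, {1, 2, 0}) == {0, 1, 2} // composing with the inverse is identity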
+struct find_reshaper
+{
+    auto matcher() const
     {
-        if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
-            continue;
-        // Gather reshapes
+        return match::name(reshaper_names())(
+            match::any_of[match::outputs()](match::name(reshaper_names())));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
         std::vector<instruction_ref> reshapes{ins};
         while(is_reshaper(reshapes.back()))
         {
@@ -83,21 +120,107 @@ void simplify_reshapes::apply(program& p) const
             p.replace_instruction(r.first, r.second);
         }
     }
-        else if(ins->name() == "transpose")
+};
+
+struct find_nop_reshapes
+{
+    auto matcher() const
     {
-            if(is_transpose_output(ins))
-                continue;
+        auto reshapes = reshaper_names();
+        reshapes.insert("transpose");
+        reshapes.insert("slice");
+        return match::name(reshapes)(match::same_shape(match::arg(0)));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        p.replace_instruction(ins, ins->inputs().front());
+    }
+};
+
+struct find_transpose
+{
+    auto matcher() const
+    {
+        return match::name("transpose")(match::none_of(
+            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
         auto x = ins;
         auto t = ins;
+        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
+        std::iota(dims.begin(), dims.end(), 0);
         do
         {
+            dims = reorder_dims(get_transpose_dims(t), dims);
             x = t;
             t = find_transpose_input(x);
         } while(x != t and t->name() == "transpose");
         if(t == ins or t->name() != "transpose")
-            continue;
+            return;
+        if(is_no_transpose(dims))
+        {
             p.replace_instruction(ins, t->inputs().front());
         }
+        else
+        {
+            p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
+        }
     }
+};
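For intuition, a hedged trace of what find_transpose folds (dims compose via reorder_dims as the walk moves inward):

    // x -> transpose{1, 0} -> transpose{1, 0}:
    //   dims: {0, 1} -> {1, 0} -> {0, 1}; is_no_transpose holds,
    //   so the pair is replaced with x directly.
    // x -> transpose{2, 0, 1} -> transpose{2, 0, 1}:
    //   dims compose to {1, 2, 0}, so the chain is rewritten as a
    //   single op::transpose{{1, 2, 0}} on x.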
+struct find_concat_transpose
+{
+    auto matcher() const
+    {
+        return match::name("concat")(match::same_input_shapes(),
+                                     match::all_of[match::inputs()](match::transpose_shape()));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        auto s   = ins->inputs().front()->get_shape();
+        assert(s.transposed());
+        auto op           = any_cast<op::concat>(ins->get_operator());
+        auto permutation  = find_permutation(s);
+        auto ipermutation = invert_permutation(permutation);
+        op.axis           = ipermutation[op.axis];
+
+        std::vector<instruction_ref> inputs;
+        std::transform(
+            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
+                if(i->name() == "transpose" and i->inputs().front()->get_shape().standard())
+                    return i->inputs().front();
+                return p.insert_instruction(ins, op::transpose{permutation}, i);
+            });
+        auto concat = p.insert_instruction(ins, op, inputs);
+        auto t      = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
+        assert(ins->get_shape().lens() == t->get_shape().lens());
+        p.replace_instruction(ins, t);
+    }
+};
+
+void simplify_reshapes::apply(program& p) const
+{
+    auto end = std::prev(p.end());
+    for(auto ins : iterator_for(p))
+    {
+        if(ins == end and ins->name() == "contiguous")
+            continue;
+        // Skip possible dead instructions
+        if(ins->outputs().empty() and ins != end)
+            continue;
+        match::find_matches(p,
+                            ins,
+                            find_nop_reshapes{},
+                            find_reshaper{},
+                            find_transpose{},
+                            find_concat_transpose{});
+    }
 }
......
@@ -13,6 +13,8 @@
 #include <migraphx/op/pad.hpp>
 #include <migraphx/op/pooling.hpp>
 #include <migraphx/op/softmax.hpp>
+#include <migraphx/op/argmax.hpp>
+#include <migraphx/op/argmin.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/par_dfor.hpp>
@@ -539,18 +541,11 @@ struct cpu_softmax
     std::string name() const { return "cpu::softmax"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

-    template <typename T>
-    std::size_t compute_batch_index(T idx, shape& batch_shape, int axis) const
-    {
-        idx[axis] = 0;
-        return batch_shape.index(idx);
-    }
-
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         auto batch_lens     = output_shape.lens();
+        std::size_t n_dims  = batch_lens[op.axis];
         batch_lens[op.axis] = 1;
         shape batch_shape{shape::int32_type, batch_lens};
@@ -558,26 +553,33 @@ struct cpu_softmax
             using value_type = typename decltype(input)::value_type;
             std::vector<value_type> batch_max(batch_shape.elements(),
                                               std::numeric_limits<value_type>::lowest());
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index       = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
-            });
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) =
-                    std::exp(input(idx.begin(), idx.end()) - batch_max[index]);
-            });
-
-            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_sum[index] += output(idx.begin(), idx.end());
-            });
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) /= batch_sum[index];
-            });
+            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
+            par_for(batch_shape.elements(), [&](auto i) {
+                auto idx = batch_shape.multi(i);
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis]      = j;
+                    std::size_t index = output_shape.index(idx);
+                    output[index]     = std::exp(input[index] - batch_max[i]);
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_sum[i] += output(idx.begin(), idx.end());
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    output(idx.begin(), idx.end()) /= batch_sum[i];
+                }
+            });
         });
@@ -597,49 +599,50 @@ struct cpu_logsoftmax
     std::string name() const { return "cpu::logsoftmax"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

-    template <typename T>
-    std::size_t compute_batch_index(T idx, const shape& batch_shape, int axis) const
-    {
-        idx[axis] = 0;
-        return batch_shape.index(idx);
-    }
-
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         auto batch_lens     = output_shape.lens();
+        std::size_t n_dims  = batch_lens[op.axis];
         batch_lens[op.axis] = 1;
         shape batch_shape{shape::int32_type, batch_lens};

+        // use a parallel implementation to achieve better performance:
+        // one thread for one batch
         visit_all(result, args[0])([&](auto output, auto input) {
             using value_type = typename decltype(input)::value_type;
             std::vector<value_type> batch_max(batch_shape.elements(),
                                               std::numeric_limits<value_type>::lowest());
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index       = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
-            });
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
-            });
-
-            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
-            });
-
-            for(std::size_t i = 0; i < batch_sum.size(); ++i)
-            {
-                batch_sum[i] = std::log(batch_sum[i]);
-            }
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) -= batch_sum[index];
-            });
+            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
+
+            par_for(batch_shape.elements(), [&](auto i) {
+                auto idx = batch_shape.multi(i);
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis]      = j;
+                    std::size_t index = output_shape.index(idx);
+                    output[index]     = input[index] - batch_max[i];
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_sum[i] += std::exp(output(idx.begin(), idx.end()));
+                }
+
+                batch_sum[i] = std::log(batch_sum[i]);
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    output(idx.begin(), idx.end()) -= batch_sum[i];
+                }
+            });
         });
......
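Both CPU kernels rely on the usual max-shift for numerical stability; a hedged scalar sketch of the same computation outside the tensor machinery (assumes a non-empty row):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Stable softmax over one batch row: subtracting the row max before
    // exponentiating keeps exp() from overflowing for large inputs.
    std::vector<float> softmax_row(std::vector<float> x)
    {
        float m   = *std::max_element(x.begin(), x.end());
        float sum = 0;
        for(auto& v : x)
        {
            v = std::exp(v - m);
            sum += v;
        }
        for(auto& v : x)
            v /= sum;
        return x;
    }

    // logsoftmax follows the same outline: (x - m) - log(sum(exp(x - m))).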
@@ -12,6 +12,8 @@ endif()
 add_library(migraphx_device
     device/add.cpp
+    device/argmax.cpp
+    device/argmin.cpp
     device/max.cpp
     device/min.cpp
     device/exp.cpp
@@ -44,6 +46,8 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR
 target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
 add_library(migraphx_gpu
+    argmax.cpp
+    argmin.cpp
     eliminate_workspace.cpp
     fuse_ops.cpp
     hip.cpp
......
#include <migraphx/gpu/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)});
}
argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::argmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)});
}
argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::argmin(ctx.get_stream().get(), args.back(), args.front(), op.axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/arg_op.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg_op(argmax_op{}, stream, result, arg, axis);
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/argmin.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/arg_op.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg_op(argmin_op{}, stream, result, arg, axis);
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -20,10 +20,12 @@ argument concat(hipStream_t stream,
         auto&& arg            = args[j];
         std::size_t nelements = arg.get_shape().elements();
         auto offset           = offsets[j];
-        hip_visit_all(args.back(), arg)([&](auto output, auto input) {
+        shape arg_shape{arg.get_shape().type(), arg.get_shape().lens()};
+        hip_visit_all(args.back(), arg, arg_shape)([&](auto output, auto input, auto input_shape) {
             gs_launch(stream, nelements)([=](auto i) {
-                auto idx = output.get_shape().index(input.get_shape().multi(i));
-                output.data()[idx + offset] = input.data()[i];
+                auto input_idx = input_shape.multi(i);
+                auto idx       = output.get_shape().index(input_idx);
+                output.data()[idx + offset] = input[input_idx];
             });
         });
     }
......