"vscode:/vscode.git/clone" did not exist on "b22ebd44857aa87b7223e53f1cf0f518569fb1d4"
Unverified commit bc80dee8, authored by mvermeulen, committed by GitHub

Merge pull request #265 from ROCmSoftwarePlatform/tf-transpose

Transpose each layer of TF
parents 8d5a2210 2ee59b2b
@@ -7,6 +7,14 @@
 #include <migraphx/onnx.hpp>
 #include <migraphx/stringutils.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_identity.hpp>
+#include <migraphx/eliminate_pad.hpp>
+#include <migraphx/propagate_constant.hpp>
+#include <migraphx/simplify_algebra.hpp>
+#include <migraphx/simplify_reshapes.hpp>
 namespace migraphx {
 namespace driver {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -17,6 +25,7 @@ struct loader
     std::string file_type;
     bool is_nhwc  = true;
     unsigned trim = 0;
+    bool optimize = false;

     void parse(argument_parser& ap)
     {
@@ -26,6 +35,7 @@ struct loader
         ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
         ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
         ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
+        ap(optimize, {"--optimize"}, ap.help("Optimize when reading"), ap.set_value(true));
     }
     program load()
@@ -48,6 +58,20 @@ struct loader
             auto last = std::prev(p.end(), trim);
             p.remove_instructions(last, p.end());
         }
+        if(optimize)
+            migraphx::run_passes(p,
+                                 {
+                                     migraphx::eliminate_identity{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::simplify_algebra{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::simplify_reshapes{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::propagate_constant{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::eliminate_pad{},
+                                     migraphx::dead_code_elimination{},
+                                 });
         return p;
     }
};
...
@@ -190,6 +190,23 @@ auto pop_back_args(Ts&&... xs)
     };
 }
+
+template <class T>
+struct always_f
+{
+    T x;
+    template <class... Ts>
+    constexpr T operator()(Ts&&...) const
+    {
+        return x;
+    }
+};
+
+template <class T>
+auto always(T x)
+{
+    return always_f<T>{x};
+}

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
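Note: `always(x)` returns a callable that ignores its arguments and yields `x`; the lazy fold in `match_fold_f` below uses it to wrap an already-computed boolean so it composes with the next matcher. A minimal illustration (the `main` harness here is hypothetical, not part of this PR):

#include <cassert>
#include <migraphx/functional.hpp>

int main()
{
    auto five = migraphx::always(5);
    assert(five() == 5);            // callable with no arguments
    assert(five(1, 'a', 3.0) == 5); // any arguments are ignored
}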
@@ -8,6 +8,7 @@
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/config.hpp>
 #include <unordered_map>
+#include <unordered_set>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -20,6 +21,12 @@ struct matcher_context
     std::unordered_map<std::string, instruction_ref> instructions;
     instruction_ref not_found() const { return last; }

+    template <class M>
+    bool matched(M m, instruction_ref ins)
+    {
+        return m.match(*this, ins) != this->not_found();
+    }
+
     private:
     instruction_ref last;
};
@@ -205,74 +212,147 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
     return result;
 }

+/// Find matches for an instruction in the program
+template <class... Ms>
+void find_matches(program& p, instruction_ref ins, Ms&&... ms)
+{
+    bool match = false;
+    each_args(
+        [&](auto&& m) {
+            if(match)
+                return;
+            auto r = match_instruction(p, ins, m.matcher());
+            if(r.result == p.end())
+                return;
+            m.apply(p, r);
+            match = true;
+        },
+        ms...);
+}
+
 /// Find matches in a program
 template <class... Ms>
 void find_matches(program& p, Ms&&... ms)
 {
     for(auto ins : iterator_for(p))
     {
-        bool match = false;
-        each_args(
-            [&](auto&& m) {
-                if(match)
-                    return;
-                auto r = match_instruction(p, ins, m.matcher());
-                if(r.result == p.end())
-                    return;
-                m.apply(p, r);
-                match = true;
-            },
-            ms...);
+        find_matches(p, ins, ms...);
     }
 }
-template <class... Ts>
-auto all_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x and y.match(ctx, ins) != ctx.not_found();
-        })(true, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
-
-template <class... Ts>
-auto none_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x and y.match(ctx, ins) == ctx.not_found();
-        })(true, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
-
-template <class... Ts>
-auto any_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x or y.match(ctx, ins) != ctx.not_found();
-        })(false, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
+struct lazy_and
+{
+    template <class F, class G>
+    bool operator()(F f, G g) const
+    {
+        return f() and g();
+    }
+};
+
+struct lazy_or
+{
+    template <class F, class G>
+    bool operator()(F f, G g) const
+    {
+        return f() or g();
+    }
+};
+
+template <class Op, bool Start, bool Matches>
+struct match_fold_f
+{
+    template <class... Ms>
+    static bool fold_matchers(matcher_context& ctx, instruction_ref ins, Ms... ms)
+    {
+        Op op;
+        auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
+        return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...);
+    }
+
+    template <class Pack>
+    static bool fold_matchers_pack(matcher_context& ctx, instruction_ref ins, Pack p)
+    {
+        return p([&](auto... ms) { return match_fold_f::fold_matchers(ctx, ins, ms...); });
+    }
+
+    template <class... Ts>
+    auto operator()(Ts... ms) const
+    {
+        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
+            bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
+            if(matches == Matches)
+                return ins;
+            return ctx.not_found();
+        });
+    }
+
+    template <class Selector>
+    auto operator[](Selector select) const
+    {
+        return [=](auto... ms) {
+            // Workaround ICE on gcc by packing matchers into an object
+            auto mpack = pack(ms...);
+            return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
+                Op op;
+                bool matches = Start;
+                select(start, [&](auto ins) {
+                    auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
+                    matches = op(always(matches), fm);
+                });
+                if(matches == Matches)
+                    return start;
+                return ctx.not_found();
+            });
+        };
+    }
+};
+
+const constexpr auto all_of  = match_fold_f<lazy_and, true, true>{};
+const constexpr auto any_of  = match_fold_f<lazy_or, false, true>{};
+const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
+
+inline auto inputs()
+{
+    return [](auto ins, auto f) {
+        for(auto&& x : ins->inputs())
+            f(x);
+    };
+}
+
+inline auto outputs()
+{
+    return [](auto ins, auto f) {
+        for(auto&& x : ins->outputs())
+            f(x);
+    };
+}
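Note: `all_of`, `any_of`, and `none_of` are now constexpr function objects built from `match_fold_f`, so they compose exactly as the old functions did, and the new `operator[]` takes a selector such as `inputs()` or `outputs()` to quantify the matchers over neighboring instructions. A sketch of the resulting syntax (matchers taken from elsewhere in this diff; the surrounding program is assumed):

// Matches an instruction all of whose inputs have transposed shapes:
auto m1 = match::all_of[match::inputs()](match::transpose_shape());
// Matches a reshape that feeds at least one other reshape:
auto m2 = match::name("reshape")(match::any_of[match::outputs()](match::name("reshape")));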
 MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
 MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
 MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
+MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
+{
+    return not ins->get_shape().standard();
+}
 MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
 {
     return ins->get_shape().broadcasted();
 }
+MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
+{
+    return ins->get_shape().transposed();
+}
+MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
+{
+    if(ins->inputs().empty())
+        return false;
+    auto s = ins->inputs().front()->get_shape();
+    return std::all_of(
+        ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
+}
 MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
 {
     if(ins->outputs().size() == 1)
@@ -289,10 +369,39 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
     return ctx.not_found();
 }

+template <class... Ms>
+auto skip_output(Ms... ms)
+{
+    auto m = any_of(ms...);
+    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
+        return fix<instruction_ref>([&](auto self, auto ins) {
+            if(ins->outputs().size() == 1)
+            {
+                auto next = ins->outputs().front();
+                if(ctx.matched(m, next))
+                {
+                    auto skipped_next = self(next);
+                    if(skipped_next != ctx.not_found())
+                        return skipped_next;
+                }
+                return next;
+            }
+            return ctx.not_found();
+        })(start);
+    });
+}
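Note: `skip_output` follows single-output chains forward, skipping over instructions matched by `ms...` and handing the first non-skipped consumer to the inner matcher; if an instruction has multiple outputs the walk stops with no match. A usage sketch drawn from `find_transpose` later in this diff:

// A transpose not ultimately consumed by another transpose, looking
// through any interposed contiguous instructions:
auto m = match::name("transpose")(
    match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose"))));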
-inline auto name(std::string name)
+inline auto name(std::string s)
 {
     return make_basic_pred_matcher(
-        [ =, name = std::move(name) ](instruction_ref ins) { return ins->name() == name; });
+        [ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
+}
+
+inline auto name(std::unordered_set<std::string> names)
+{
+    return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
+        return names.count(ins->name()) > 0;
+    });
 }
 inline auto nargs(std::size_t n)
@@ -338,6 +447,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
     };
 }

+template <class M>
+auto same_shape(M m)
+{
+    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
+        auto i = m.match(ctx, ins);
+        if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
+            return ins;
+        return ctx.not_found();
+    });
+}
+
+template <class... Ms>
+auto same_shape(Ms... ms)
+{
+    return all_of(same_shape(ms)...);
+}
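Note: `same_shape(m)` succeeds when the instruction located by `m` has the same shape as the instruction being matched, and the variadic overload requires that for every sub-matcher. For example (the single-matcher form is used by `find_nop_reshapes` below):

// A reshape whose output shape equals its input's shape is a no-op:
auto m = match::name("reshape")(match::same_shape(match::arg(0)));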
 } // namespace match
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -2,14 +2,17 @@
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/op/as_shape.hpp>
+#include <migraphx/op/transpose.hpp>
+#include <migraphx/op/concat.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/ranges.hpp>
+#include <migraphx/matcher.hpp>
 #include <unordered_set>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-bool is_reshaper(instruction_ref ins)
+const auto& reshaper_names()
 {
     // clang-format off
     static const std::unordered_set<std::string> names = {
@@ -19,17 +22,10 @@ bool is_reshaper(instruction_ref ins)
         "unsqueeze"
     };
     // clang-format on
-    return contains(names, ins->name());
+    return names;
 }

-bool is_transpose_output(instruction_ref ins)
-{
-    if(ins->outputs().size() != 1)
-        return false;
-    if(ins->outputs().front()->name() == "contiguous")
-        return is_transpose_output(ins->outputs().front());
-    return ins->outputs().front()->name() == "transpose";
-}
+bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); }

 instruction_ref find_transpose_input(instruction_ref ins)
 {
@@ -42,62 +38,189 @@ instruction_ref find_transpose_input(instruction_ref ins)
     return ins;
 }
+auto get_transpose_dims(instruction_ref ins)
+{
+    return any_cast<const op::transpose&>(ins->get_operator()).dims;
+}
+
+std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t> permutation)
+{
+    std::vector<int64_t> result(dims.size());
+    assert(dims.size() == permutation.size());
+    for(std::size_t i = 0; i < dims.size(); i++)
+    {
+        result[i] = dims[permutation[i]];
+    }
+    return result;
+}
+
+bool is_no_transpose(const std::vector<int64_t>& dims)
+{
+    if(dims.empty())
+        return true;
+    if(dims.front() != 0)
+        return false;
+    return std::adjacent_find(
+               dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end();
+}
+
+template <class Vector, class Op>
+std::vector<int64_t> sort_permutation(const Vector& data, Op op)
+{
+    std::vector<std::int64_t> result(data.size());
+    std::iota(result.begin(), result.end(), 0);
+    std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
+    return result;
+}
+
+std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
+{
+    return sort_permutation(permutation, std::less<>{});
+}
+
+std::vector<int64_t> find_permutation(const shape& s)
+{
+    return sort_permutation(s.strides(), std::greater<>{});
+}
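A quick worked check of these helpers (values chosen for illustration): `reorder_dims` gathers `dims` through the permutation, `invert_permutation` recovers the permutation that undoes it, and `is_no_transpose` recognizes identity permutations.

assert(reorder_dims({10, 20, 30, 40}, {0, 2, 3, 1}) == (std::vector<int64_t>{10, 30, 40, 20}));
assert(invert_permutation({0, 2, 3, 1}) == (std::vector<int64_t>{0, 3, 1, 2}));
assert(is_no_transpose({0, 1, 2, 3}));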
-void simplify_reshapes::apply(program& p) const
-{
-    auto end = std::prev(p.end());
-    for(auto ins : iterator_for(p))
-    {
-        if(ins == end and ins->name() == "contiguous")
-            continue;
-        // Skip possible dead instructions
-        if(ins->outputs().empty() and ins != end)
-            continue;
-        if(is_reshaper(ins))
-        {
-            if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
-                continue;
-            // Gather reshapes
-            std::vector<instruction_ref> reshapes{ins};
-            while(is_reshaper(reshapes.back()))
-            {
-                assert(!reshapes.back()->inputs().empty());
-                assert(p.has_instruction(reshapes.back()->inputs().front()));
-                auto input = reshapes.back()->inputs().front();
-                reshapes.push_back(input);
-            }
-            std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
-            for(auto start : iterator_for(reshapes))
-            {
-                auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
-                    return i->get_shape() == (*start)->get_shape() and i != (*start);
-                });
-                if(last != reshapes.rend())
-                {
-                    r = std::make_pair(*start, *last);
-                    break;
-                }
-            }
-            if(r.first != r.second)
-            {
-                p.replace_instruction(r.first, r.second);
-            }
-        }
-        else if(ins->name() == "transpose")
-        {
-            if(is_transpose_output(ins))
-                continue;
-            auto x = ins;
-            auto t = ins;
-            do
-            {
-                x = t;
-                t = find_transpose_input(x);
-            } while(x != t and t->name() == "transpose");
-            if(t == ins or t->name() != "transpose")
-                continue;
-            p.replace_instruction(ins, t->inputs().front());
-        }
-    }
-}
+struct find_reshaper
+{
+    auto matcher() const
+    {
+        return match::name(reshaper_names())(
+            match::any_of[match::outputs()](match::name(reshaper_names())));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        std::vector<instruction_ref> reshapes{ins};
+        while(is_reshaper(reshapes.back()))
+        {
+            assert(!reshapes.back()->inputs().empty());
+            assert(p.has_instruction(reshapes.back()->inputs().front()));
+            auto input = reshapes.back()->inputs().front();
+            reshapes.push_back(input);
+        }
+
+        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
+        for(auto start : iterator_for(reshapes))
+        {
+            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
+                return i->get_shape() == (*start)->get_shape() and i != (*start);
+            });
+            if(last != reshapes.rend())
+            {
+                r = std::make_pair(*start, *last);
+                break;
+            }
+        }
+        if(r.first != r.second)
+        {
+            p.replace_instruction(r.first, r.second);
+        }
+    }
+};
+
+struct find_nop_reshapes
+{
+    auto matcher() const
+    {
+        auto reshapes = reshaper_names();
+        reshapes.insert("transpose");
+        reshapes.insert("slice");
+        return match::name(reshapes)(match::same_shape(match::arg(0)));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        p.replace_instruction(ins, ins->inputs().front());
+    }
+};
+
+struct find_transpose
+{
+    auto matcher() const
+    {
+        return match::name("transpose")(match::none_of(
+            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        auto x   = ins;
+        auto t   = ins;
+        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
+        std::iota(dims.begin(), dims.end(), 0);
+        do
+        {
+            dims = reorder_dims(get_transpose_dims(t), dims);
+            x    = t;
+            t    = find_transpose_input(x);
+        } while(x != t and t->name() == "transpose");
+        if(t == ins or t->name() != "transpose")
+            return;
+        if(is_no_transpose(dims))
+        {
+            p.replace_instruction(ins, t->inputs().front());
+        }
+        else
+        {
+            p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
+        }
+    }
+};
+
+struct find_concat_transpose
+{
+    auto matcher() const
+    {
+        return match::name("concat")(match::same_input_shapes(),
+                                     match::all_of[match::inputs()](match::transpose_shape()));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        auto s   = ins->inputs().front()->get_shape();
+        assert(s.transposed());
+        auto op           = any_cast<op::concat>(ins->get_operator());
+        auto permutation  = find_permutation(s);
+        auto ipermutation = invert_permutation(permutation);
+        op.axis           = ipermutation[op.axis];
+
+        std::vector<instruction_ref> inputs;
+        std::transform(
+            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
+                if(i->name() == "transpose" and i->inputs().front()->get_shape().standard())
+                    return i->inputs().front();
+                return p.insert_instruction(ins, op::transpose{permutation}, i);
+            });
+        auto concat = p.insert_instruction(ins, op, inputs);
+        auto t      = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
+        assert(ins->get_shape().lens() == t->get_shape().lens());
+        p.replace_instruction(ins, t);
+    }
+};
+
+void simplify_reshapes::apply(program& p) const
+{
+    auto end = std::prev(p.end());
+    for(auto ins : iterator_for(p))
+    {
+        if(ins == end and ins->name() == "contiguous")
+            continue;
+        // Skip possible dead instructions
+        if(ins->outputs().empty() and ins != end)
+            continue;
+        match::find_matches(p,
+                            ins,
+                            find_nop_reshapes{},
+                            find_reshaper{},
+                            find_transpose{},
+                            find_concat_transpose{});
+    }
+}
...
@@ -20,10 +20,12 @@ argument concat(hipStream_t stream,
         auto&& arg            = args[j];
         std::size_t nelements = arg.get_shape().elements();
         auto offset           = offsets[j];
-        hip_visit_all(args.back(), arg)([&](auto output, auto input) {
+        shape arg_shape{arg.get_shape().type(), arg.get_shape().lens()};
+        hip_visit_all(args.back(), arg, arg_shape)([&](auto output, auto input, auto input_shape) {
             gs_launch(stream, nelements)([=](auto i) {
-                auto idx                    = output.get_shape().index(input.get_shape().multi(i));
-                output.data()[idx + offset] = input.data()[i];
+                auto input_idx              = input_shape.multi(i);
+                auto idx                    = output.get_shape().index(input_idx);
+                output.data()[idx + offset] = input[input_idx];
             });
         });
     }
...
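Note: the new `arg_shape` is constructed from only the type and lengths of `arg`, so it is packed and row-major even when `arg` itself is transposed or otherwise non-standard; `input_shape.multi(i)` then enumerates elements in a well-defined order while `input[input_idx]` reads through the input's real strides. A standalone sketch of the flat-to-multi conversion this relies on (illustrative arithmetic, not the migraphx API):

#include <cstddef>
#include <vector>

// Convert a flat element index into a multi-dimensional index for a
// packed row-major shape with the given dimension lengths.
std::vector<std::size_t> multi_index(std::size_t i, const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> idx(lens.size());
    for(std::size_t d = lens.size(); d > 0; d--)
    {
        idx[d - 1] = i % lens[d - 1];
        i /= lens[d - 1];
    }
    return idx;
}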
@@ -200,12 +200,33 @@ struct hip_add_relu
     }
 };
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
    // Ensure the last argument is the broadcasted one
auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
if(it != args.end())
std::swap(*it, *std::prev(args.end(), 2));
}
void move_standard_front(std::vector<instruction_ref>& args)
{
    // Ensure the first argument is the standard one
auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
if(it != args.end())
std::swap(*it, args.front());
}
 struct find_add_relu
 {
     auto matcher() const
     {
-        return match::name("gpu::relu")(match::arg(0)(
-            match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
+        return match::name("gpu::relu")(
+            match::arg(0)(match::any_of(match::name("gpu::add"),
+                                        match::name("hip::triadd"),
+                                        match::any_of[match::inputs()](match::standard_shape()))
+                              .bind("add")));
     }
     void apply(program& p, match::matcher_result r) const
@@ -213,6 +234,9 @@ struct find_add_relu
         auto add_ins = r.instructions["add"];
         auto ins     = r.result;
         auto args    = add_ins->inputs();
+        move_standard_front(args);
+        move_broadcasted_back(args);
+
         // Use the allocation from the relu operator
         args.back() = ins->inputs().back();
         if(add_ins->name() == "gpu::add")
@@ -226,24 +250,26 @@ struct find_triadd
 {
     auto matcher() const
     {
-        return match::name("gpu::add")(match::either_arg(0, 1)(match::name("gpu::add").bind("add"),
-                                                               match::any().bind("input")));
+        return match::name("gpu::add")(match::either_arg(0, 1)(
+            match::name("gpu::add").bind("add"),
+            match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
     }

     void apply(program& p, match::matcher_result r) const
     {
         auto add_ins   = r.instructions["add"];
         auto input_ins = r.instructions["input"];
         auto ins       = r.result;
         auto args      = add_ins->inputs();
+        assert(add_ins != input_ins);
         auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
         if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
             return;
         args.insert(args.begin(), input_ins);
-        // Ensure the last arguments is the broadcasted one
-        auto it = std::find_if(args.begin(), args.end(), is_broadcasted);
-        if(it != args.end())
-            std::swap(*it, *std::prev(args.end(), 2));
+        move_standard_front(args);
+        move_broadcasted_back(args);
         args.back() = ins->inputs().back();
         p.replace_instruction(ins, hip_triadd{}, args);
     }
@@ -402,8 +428,8 @@ void fuse_ops::apply(program& p) const
     // clang-format off
     match::find_matches(p, find_triadd{});
     match::find_matches(p,
-        // find_conv_bias_relu{ctx},
-        // find_conv_bias{ctx},
+        find_conv_bias_relu{ctx},
+        find_conv_bias{ctx},
         find_add_relu{}
     );
     // clang-format on
...
@@ -36,6 +36,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
     // clang-format off
     return
     {
+        dead_code_elimination{},
+        simplify_reshapes{},
        dead_code_elimination{},
        eliminate_identity{},
        eliminate_pad{},
@@ -48,11 +50,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
        //dead_code_elimination{},
        simplify_algebra{},
        dead_code_elimination{},
-        propagate_constant{},
-        dead_code_elimination{},
        auto_contiguous{},
        simplify_reshapes{},
        dead_code_elimination{},
+        propagate_constant{},
+        dead_code_elimination{},
        lowering{ctx},
        eliminate_concat{concat_gpu_optimization{}},
        dead_code_elimination{},
...
@@ -37,6 +37,48 @@ struct tf_parser
     std::unordered_map<std::string, op_func> ops;
bool should_transpose(instruction_ref ins) const
{
return is_nhwc and ins->get_shape().lens().size() == 4;
}
instruction_ref to_nhwc(instruction_ref ins)
{
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
return ins;
}
instruction_ref to_nchw(instruction_ref ins)
{
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
return ins;
}
instruction_ref to_kcxy(instruction_ref ins)
{
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
return ins;
}
instruction_ref make_contiguous(instruction_ref ins)
{
if(ins->get_shape().standard())
return ins;
else
return prog.add_instruction(op::contiguous{}, ins);
}
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
{
std::vector<instruction_ref> result(args.size());
std::transform(
args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
return result;
}
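Note on the permutations: for a 4-D tensor, `to_nchw` applies `{0, 3, 1, 2}`, turning an NHWC shape `{N, H, W, C}` into `{N, C, H, W}`; `to_nhwc` applies `{0, 2, 3, 1}` to invert it; and `to_kcxy` applies `{3, 2, 0, 1}`, mapping TF's HWIO weight layout `{H, W, I, O}` to `{O, I, H, W}`. Illustrative arithmetic only (not part of this PR):

// apply_perm({1, 16, 16, 3}, {0, 3, 1, 2}) == {1, 3, 16, 16}
std::vector<std::size_t> apply_perm(const std::vector<std::size_t>& lens,
                                    const std::vector<std::size_t>& perm)
{
    std::vector<std::size_t> out(lens.size());
    for(std::size_t i = 0; i < lens.size(); i++)
        out[i] = lens[perm[i]];
    return out;
}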
     std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
     {
         auto attrs = attributes.at(s).list().i();
@@ -119,59 +161,67 @@ struct tf_parser
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd); add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("ConcatV2", &tf_parser::parse_concat); add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv); add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul); add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean); add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pack", &tf_parser::parse_pack); add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad); add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape); add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
} }
     template <class F>
-    void add_op(std::string name, F f)
-    {
-        ops.emplace(name, f);
-    }
-
-    // Multi output op
-    template <class F>
-    void add_multi_op(std::string name, F f)
-    {
-        ops.emplace(name, f);
-    }
+    void add_op(std::string name, F f, bool transpose = true)
+    {
+        if(transpose)
+        {
+            ops.emplace(name,
+                        op_func{[=](const attribute_map& attributes,
+                                    const std::vector<instruction_ref>& args) -> instruction_ref {
+                            return to_nhwc(f(attributes, to_nchw(args)));
+                        }});
+        }
+        else
+        {
+            ops.emplace(name, f);
+        }
+    }

     template <class F>
-    void add_mem_op(std::string name, F f)
+    void add_mem_op(std::string name, F f, bool transpose = true)
     {
-        add_op(name, [=](auto&&... xs) {
-            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
-        });
+        add_op(name,
+               [=](auto&&... xs) {
+                   return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
+               },
+               transpose);
     }
     template <class T>
     void add_binary_op(std::string name, T x)
     {
-        add_op(name, [this, x](const attribute_map& attributes, std::vector<instruction_ref> args) {
-            if(args.size() != 2)
-                MIGRAPHX_THROW("binary operators should have 2 operands");
-            auto l0 = args[1];
-            if(contains(attributes, "data_format"))
-            {
-                if(is_nhwc)
-                {
-                    l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
-                }
-            }
-            return add_broadcastable_binary_op(args[0], l0, x);
-        });
+        add_op(name,
+               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
+                   if(args.size() != 2)
+                       MIGRAPHX_THROW("binary operators should have 2 operands");
+                   // TODO
+                   // if(contains(attributes, "data_format"))
+                   // {
+                   //     if(is_nhwc)
+                   //     {
+                   //         l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
+                   //     }
+                   // }
+                   return add_broadcastable_binary_op(args[0], args[1], x);
+               },
+               false);
     }
     template <class T>
@@ -210,20 +260,22 @@ struct tf_parser
             auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
             auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
-            return prog.add_instruction(x, l0, l1);
+            return to_nhwc(prog.add_instruction(x, to_nchw(l0), to_nchw(l1)));
         }
         else
         {
-            return prog.add_instruction(x, {arg0, arg1});
+            return to_nhwc(prog.add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
         }
     }
     template <class T>
-    void add_generic_op(std::string name, T x)
+    void add_generic_op(std::string name, T x, bool transpose = true)
     {
-        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
-            return prog.add_instruction(x, args);
-        });
+        add_op(name,
+               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
+                   return prog.add_instruction(x, args);
+               },
+               transpose);
     }
     instruction_ref
@@ -253,7 +305,7 @@ struct tf_parser
     {
         // get index for axis within args
         size_t axis_idx = attributes.at("N").i();
-        size_t axis     = parse_axis(args[axis_idx]->eval().at<int64_t>());
+        size_t axis     = args[axis_idx]->eval().at<int64_t>();
         op::concat op{axis};
         // return only first N arguments (assuming last index is the axis value)
         return prog.add_instruction(
@@ -264,16 +316,8 @@ struct tf_parser
                                     attribute_map attributes,
                                     const std::vector<instruction_ref>&)
     {
         literal v = parse_tensor(attributes.at("value").tensor());
-        auto l0   = prog.add_literal(v);
-        size_t num_axes = l0->get_shape().lens().size();
-        if(num_axes >= 4)
-        {
-            std::vector<int64_t> transpose_axes = get_axes(num_axes);
-            reorder_data(transpose_axes);
-            l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
-        }
-        return l0;
+        return prog.add_literal(v);
     }
     instruction_ref
@@ -304,22 +348,9 @@ struct tf_parser
             op.dilation[0] = dilation[2];
             op.dilation[1] = dilation[3];
         }
-        auto weights = args[1];
-        // check if weights are from a constant
-        if(weights->name() != "@param")
-        {
-            if(is_nhwc)
-            {
-                weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
-            }
-            else
-            {
-                weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
-            }
-        }
+        auto weights = to_kcxy(args[1]);
         auto l0      = args[0];
         if(contains(attributes, "padding"))
         {
             const std::string& pad_mode = attributes.at("padding").s();
@@ -368,8 +399,7 @@ struct tf_parser
                 op.padding[1] = padding[1];
             }
         }
-
-        return prog.add_instruction(op, {l0, weights});
+        return prog.add_instruction(op, {l0, to_kcxy(args[1])});
     }
     instruction_ref parse_depthwiseconv(const std::string&,
@@ -392,6 +422,8 @@ struct tf_parser
             op.stride[0] = stride[2];
             op.stride[1] = stride[3];
         }
+        auto weights = to_kcxy(args[1]);
+
         if(contains(attributes, "dilations"))
         {
             std::vector<size_t> dilation;
@@ -405,20 +437,6 @@ struct tf_parser
             op.dilation[1] = dilation[3];
         }
-        auto weights = args[1];
-        // check if weights are from a constant
-        if(weights->name() != "@param")
-        {
-            if(is_nhwc)
-            {
-                weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
-            }
-            else
-            {
-                weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
-            }
-        }
         auto l0 = args[0];
         if(contains(attributes, "padding"))
         {
@@ -466,8 +484,8 @@ struct tf_parser
         new_weights_shape[0] = out_channels;
         new_weights_shape[1] = 1;
         // Make sure weights are contiguous before doing reshape
-        auto cweights    = prog.add_instruction(op::contiguous{}, weights);
-        auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, cweights);
+        auto new_weights =
+            prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
         return prog.add_instruction(op, {l0, new_weights});
     }
@@ -535,16 +553,14 @@ struct tf_parser
             MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
                            " must be smaller than input size " + to_string(input_size));
         }
-        // check if input arg needs axis to be converted to NCHW
-        if(input_size >= 4)
-            axis = parse_axis(axis);
         std::transform(
             args.begin(),
             args.end(),
             std::back_inserter(unsqueezed_args),
             [&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
-        return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
+        return to_nhwc(
+            prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
     }
     instruction_ref
@@ -647,7 +663,7 @@ struct tf_parser
             MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
         auto s = args[1]->eval();
         s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
-        return prog.add_instruction(op, args[0]);
+        return prog.add_instruction(op, make_contiguous(args[0]));
     }
     void parse_from(std::istream& is)
@@ -678,7 +694,7 @@ struct tf_parser
                                    std::vector<instruction_ref> args)
     {
         op::squeeze op;
-        auto axes = parse_axes(attributes, "squeeze_dims");
+        auto axes = attributes.at("squeeze_dims").list().i();
         copy(axes, std::back_inserter(op.axes));
         auto args0_dims = args[0]->get_shape().lens();
         if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
@@ -691,7 +707,7 @@ struct tf_parser
             }
         }
-        return prog.add_instruction(op, args[0]);
+        return prog.add_instruction(op, make_contiguous(args[0]));
     }
     instruction_ref parse_stridedslice(const std::string&,
@@ -702,11 +718,6 @@ struct tf_parser
         auto starts     = args[1]->eval().get<int32_t>().to_vector();
         auto ends       = args[2]->eval().get<int32_t>().to_vector();
         size_t num_axes = args[0]->get_shape().lens().size();
-        if(num_axes >= 4)
-        {
-            reorder_data(starts);
-            reorder_data(ends);
-        }
         op.starts = std::vector<int64_t>(starts.begin(), starts.end());
         op.ends   = std::vector<int64_t>(ends.begin(), ends.end());
@@ -725,13 +736,9 @@ struct tf_parser
             if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
                 squeeze_axes.push_back(i);
         }
-        if(num_axes >= 4)
-        {
-            squeeze_axes = parse_axes(squeeze_axes);
-        }
-        auto l0 = prog.add_instruction(op, args[0]);
-        return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
+        auto l0 = prog.add_instruction(op, make_contiguous(args[0]));
+        return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
     }
     void parse_graph(const tensorflow::GraphDef& graph)
@@ -748,7 +755,7 @@ struct tf_parser
             reorder_data(dims);
         }
         shape s = shape{shape_type, dims};
-        instructions[name] = prog.add_parameter(name, s);
+        instructions[name] = to_nhwc(prog.add_parameter(name, s));
     }
     for(auto&& p : nodes)
     {
@@ -1098,6 +1105,7 @@ program parse_tf(const std::string& name, bool is_nhwc)
 #else
     parser.parse_from(input);
 #endif
+    parser.to_nchw(std::prev(parser.prog.end()));
     return std::move(parser.prog);
 }
...
@@ -148,6 +148,56 @@ TEST_CASE(match_arg7)
     EXPECT(bool{r.result == sum});
 }
TEST_CASE(match_arg8)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))),
match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::nargs(2)));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_args1)
{
    migraphx::program p;
@@ -307,6 +357,19 @@ TEST_CASE(match_all_of2)
    EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_all_of3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::all_of(
match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_any_of1)
{
    migraphx::program p;
@@ -359,6 +422,132 @@ TEST_CASE(match_none_of2)
    EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::output(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_skip_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass = p.add_instruction(pass_op{}, minus);
auto sum = p.add_instruction(sum_op{}, minus_pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass1 = p.add_instruction(pass_op{}, minus);
auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2);
auto sum = p.add_instruction(sum_op{}, minus_pass3, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum = p.add_instruction(sum_op{}, pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == two});
}
TEST_CASE(match_skip_output5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum1 = p.add_instruction(sum_op{}, pass, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_skip_output6)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum1 = p.add_instruction(sum_op{}, minus, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output7)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus1 = p.add_instruction(minus_op{}, two, one);
auto minus2 = p.add_instruction(minus_op{}, two, minus1);
auto sum = p.add_instruction(sum_op{}, one, minus2);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus1});
}
TEST_CASE(match_bind1)
{
    migraphx::program p;
...
@@ -38,7 +38,7 @@ TEST_CASE(test_shape_packed)
     EXPECT(not s.broadcasted());
 }

-TEST_CASE(test_shape_transposed)
+TEST_CASE(test_shape_transposed1)
 {
     migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
     EXPECT(not s.standard());
@@ -47,6 +47,15 @@ TEST_CASE(test_shape_transposed)
     EXPECT(not s.broadcasted());
 }
TEST_CASE(test_shape_transposed2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1, 1, 2}, {2, 2, 2, 2, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_broadcasted)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
...
 #include <migraphx/simplify_reshapes.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/operators.hpp>
+#include <migraphx/instruction.hpp>
 #include <basic_ops.hpp>
 #include <test.hpp>
@@ -165,4 +166,144 @@ TEST_CASE(transpose_double_contiguous)
     EXPECT(p.has_instruction(t));
 }
TEST_CASE(transpose_partial1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(transpose_partial2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_instruction(pass_op{}, t3);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
}
TEST_CASE(transpose_partial3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}
TEST_CASE(nop_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(nop_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 4);
}
TEST_CASE(nop_transpose3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto concat = p.add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(concat_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}
TEST_CASE(concat_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }

 #include <iostream>
 #include <vector>
 #include <migraphx/literal.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/simplify_reshapes.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_identity.hpp>
 #include <migraphx/operators.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/tf.hpp>
 #include "test.hpp"
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
auto prog = migraphx::parse_tf(name, is_nhwc);
if(is_nhwc)
migraphx::run_passes(prog,
{migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::eliminate_identity{}});
return prog;
}
 TEST_CASE(add_test)
 {
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
     p.add_instruction(migraphx::op::add{}, l0, l1);
-    auto prog = migraphx::parse_tf("add_test.pb", false);
+    auto prog = optimize_tf("add_test.pb", false);
     EXPECT(p == prog);
 }
@@ -28,7 +43,7 @@ TEST_CASE(add_bcast_test)
     auto l2 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0);
     auto l3 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
     p.add_instruction(migraphx::op::add{}, l2, l3);
-    auto prog = migraphx::parse_tf("add_bcast_test.pb", false);
+    auto prog = optimize_tf("add_bcast_test.pb", false);
     EXPECT(p == prog);
 }
@@ -51,7 +66,7 @@ TEST_CASE(batchnorm_test)
     auto l4 = p.add_parameter("4", s0);
     auto l1 = p.add_literal(migraphx::literal{s0, const_vals});
     p.add_instruction(op, l0, l1, l2, l3, l4);
-    auto prog = migraphx::parse_tf("batchnorm_test.pb", true);
+    auto prog = optimize_tf("batchnorm_test.pb", true);
     EXPECT(p == prog);
 }
@@ -65,7 +80,7 @@ TEST_CASE(biasadd_test)
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
     auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
     p.add_instruction(migraphx::op::add{}, l0, l2);
-    auto prog = migraphx::parse_tf("biasadd_test.pb", true);
+    auto prog = optimize_tf("biasadd_test.pb", true);
     EXPECT(p == prog);
 }
@@ -83,7 +98,7 @@ TEST_CASE(concat_test)
     p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
     p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
-    auto prog = migraphx::parse_tf("concat_test.pb", false);
+    auto prog = optimize_tf("concat_test.pb", false);
     EXPECT(p == prog);
 }
@@ -92,7 +107,7 @@ TEST_CASE(const_test)
 {
     migraphx::program p;
     p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
-    auto prog = migraphx::parse_tf("constant_test.pb", false);
+    auto prog = optimize_tf("constant_test.pb", false);
     EXPECT(p == prog);
 }
@@ -112,10 +127,9 @@ TEST_CASE(conv_test)
     op.padding = {1, 1};
     op.stride = {1, 1};
     op.dilation = {1, 1};
-    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
-    auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
-    p.add_instruction(op, l0, l3);
-    auto prog = migraphx::parse_tf("conv_test.pb", true);
+    auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
+    p.add_instruction(op, l0, l2);
+    auto prog = optimize_tf("conv_test.pb", true);
     EXPECT(p == prog);
 }
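The expected graph for conv_test now carries a single weight transpose: the old pair {0, 3, 1, 2} followed by {1, 3, 0, 2} composes to {3, 2, 0, 1}, which rearranges TF's HWIO filter layout directly into the OIHW layout the convolution expects. A small standalone sketch of that composition rule (the compose helper is hypothetical, not a MIGraphX API):

#include <cstdint>
#include <vector>

// For transposes applied back to back, output axis i of the second
// transpose reads axis a[b[i]] of the original tensor, so the fused
// permutation is c[i] = a[b[i]].
std::vector<int64_t> compose(const std::vector<int64_t>& a, const std::vector<int64_t>& b)
{
    std::vector<int64_t> c(b.size());
    for(std::size_t i = 0; i < b.size(); i++)
        c[i] = a[b[i]];
    return c;
}
// compose({0, 3, 1, 2}, {1, 3, 0, 2}) == {3, 2, 0, 1}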
@@ -136,12 +150,11 @@ TEST_CASE(depthwiseconv_test)
     op.stride = {1, 1};
     op.dilation = {1, 1};
     op.group = 3;
-    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
-    auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
+    auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
     auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3);
     auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
     p.add_instruction(op, l0, l5);
-    auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true);
+    auto prog = optimize_tf("depthwise_conv_test.pb", true);
     EXPECT(p == prog);
 }
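The same fusion applies in the depthwise case, followed by the reshape a grouped convolution needs. Assuming a 3x3x3x1 TF depthwise filter (H, W, C, multiplier) — the parameter's shape is elided in this hunk — the dimensions trace as follows (permute is a hypothetical helper shown only for the bookkeeping):

#include <cstddef>
#include <cstdint>
#include <vector>

// Reorder dims according to a transpose permutation: out[i] = lens[perm[i]].
std::vector<std::size_t> permute(const std::vector<std::size_t>& lens,
                                 const std::vector<int64_t>& perm)
{
    std::vector<std::size_t> out(perm.size());
    for(std::size_t i = 0; i < perm.size(); i++)
        out[i] = lens[perm[i]];
    return out;
}
// permute({3, 3, 3, 1}, {3, 2, 0, 1}) == {1, 3, 3, 3}  (M, C, H, W);
// the contiguous + reshape then yields {3, 1, 3, 3}, i.e. one output
// channel per input channel for the group = 3 convolution.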
@@ -151,7 +164,7 @@ TEST_CASE(identity_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_instruction(migraphx::op::identity{}, l0);
-    auto prog = migraphx::parse_tf("identity_test.pb", false);
+    auto prog = optimize_tf("identity_test.pb", false);
     EXPECT(p == prog);
 }
@@ -166,7 +179,7 @@ TEST_CASE(matmul_test)
     auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
     p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
-    auto prog = migraphx::parse_tf("matmul_test.pb", false);
+    auto prog = optimize_tf("matmul_test.pb", false);
     EXPECT(p == prog);
 }
@@ -183,7 +196,7 @@ TEST_CASE(mean_test)
     p.add_instruction(op, l0);
     auto l3 = p.add_instruction(op, l0);
     p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
-    auto prog = migraphx::parse_tf("mean_test.pb", false);
+    auto prog = optimize_tf("mean_test.pb", false);
     EXPECT(p == prog);
 }
@@ -193,14 +206,11 @@ TEST_CASE(mean_test_nhwc)
     migraphx::program p;
     migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
-    p.add_literal(l);
-    p.add_literal(l);
     migraphx::op::pooling op;
     op.lengths = {16, 16};
-    p.add_instruction(op, l0);
-    auto l3 = p.add_instruction(op, l0);
+    auto l3 = p.add_instruction(op, l0);
     p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
-    auto prog = migraphx::parse_tf("mean_test_nhwc.pb", true);
+    auto prog = optimize_tf("mean_test_nhwc.pb", true);
     EXPECT(p == prog);
 }
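The axis literals and the duplicate pooling drop out of the expected NHWC program because optimize_tf runs dead_code_elimination: any result nothing reads should be removed. A minimal sketch of that behavior (names are illustrative):

migraphx::program q;
auto x = q.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
// Neither of these results is ever read, so dead_code_elimination
// should remove both of them.
q.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}});
q.add_instruction(migraphx::op::relu{}, x);
// This chain survives because it produces the program's output.
auto y = q.add_instruction(migraphx::op::relu{}, x);
q.add_instruction(migraphx::op::identity{}, y);
migraphx::run_passes(q, {migraphx::dead_code_elimination{}});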
@@ -212,7 +222,7 @@ TEST_CASE(mul_test)
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
     p.add_instruction(migraphx::op::mul{}, l0, l1);
-    auto prog = migraphx::parse_tf("mul_test.pb", false);
+    auto prog = optimize_tf("mul_test.pb", false);
     EXPECT(p == prog);
 }
@@ -234,7 +244,7 @@ TEST_CASE(pack_test)
                        return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
                    });
     p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
-    auto prog = migraphx::parse_tf("pack_test.pb", false);
+    auto prog = optimize_tf("pack_test.pb", false);
     EXPECT(p == prog);
 }
@@ -242,12 +252,15 @@ TEST_CASE(pack_test)
 TEST_CASE(pack_test_nhwc)
 {
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
-    auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
-    auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
-    std::vector<migraphx::instruction_ref> args{l0, l1, l2};
+    auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
+    auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
+    auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
+    auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
+    auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
+    std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
     std::vector<migraphx::instruction_ref> unsqueezed_args;
-    int64_t nchw_axis = 1;
+    int64_t nchw_axis = 3;
     std::transform(args.begin(),
                    args.end(),
@@ -256,7 +269,7 @@ TEST_CASE(pack_test_nhwc)
                        return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
                    });
     p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
-    auto prog = migraphx::parse_tf("pack_test_nhwc.pb", true);
+    auto prog = optimize_tf("pack_test_nhwc.pb", true);
     EXPECT(p == prog);
 }
@@ -273,9 +286,9 @@ TEST_CASE(pooling_test)
     max_pool_op.stride = {2, 2};
     avg_pool_op.lengths = {2, 2};
     max_pool_op.lengths = {2, 2};
-    p.add_instruction(avg_pool_op, l0);
     p.add_instruction(max_pool_op, l0);
-    auto prog = migraphx::parse_tf("pooling_test.pb", true);
+    // p.add_instruction(avg_pool_op, l0);
+    auto prog = optimize_tf("pooling_test.pb", true);
     EXPECT(p == prog);
 }
@@ -285,7 +298,7 @@ TEST_CASE(relu_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_instruction(migraphx::op::relu{}, l0);
-    auto prog = migraphx::parse_tf("relu_test.pb", false);
+    auto prog = optimize_tf("relu_test.pb", false);
     EXPECT(p == prog);
 }
@@ -295,7 +308,7 @@ TEST_CASE(relu6_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
-    auto prog = migraphx::parse_tf("relu6_test.pb", false);
+    auto prog = optimize_tf("relu6_test.pb", false);
     EXPECT(p == prog);
 }
@@ -308,7 +321,7 @@ TEST_CASE(reshape_test)
     // in tf, the second arg is a literal that contains new dimensions
     p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
     p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
-    auto prog = migraphx::parse_tf("reshape_test.pb", false);
+    auto prog = optimize_tf("reshape_test.pb", false);
     EXPECT(p == prog);
 }
@@ -321,7 +334,7 @@ TEST_CASE(softmax_test)
     auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
     auto s = p.add_instruction(migraphx::op::softmax{}, r);
     p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s);
-    auto prog = migraphx::parse_tf("softmax_test.pb", false);
+    auto prog = optimize_tf("softmax_test.pb", false);
     EXPECT(p == prog);
 }
@@ -331,7 +344,7 @@ TEST_CASE(squeeze_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
     p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
-    auto prog = migraphx::parse_tf("squeeze_test.pb", false);
+    auto prog = optimize_tf("squeeze_test.pb", false);
     EXPECT(p == prog);
 }
@@ -343,18 +356,13 @@ TEST_CASE(stridedslice_test)
     std::size_t num_axes = 4;
     migraphx::op::slice op;
     op.starts = {0, 0, 0, 0};
-    op.ends = {1, 5, 1, 1};
+    op.ends = {1, 1, 1, 5};
     op.axes = std::vector<int64_t>(num_axes);
     std::iota(op.axes.begin(), op.axes.end(), 0);
-    // add literals for starts, ends, and strides in tf (NHWC format)
-    p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
-    p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 5});
-    p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
     auto l1 = p.add_instruction(op, l0);
-    auto shrink_axis = 2;
+    auto shrink_axis = 1;
     p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
-    auto prog = migraphx::parse_tf("stridedslice_test.pb", true);
+    auto prog = optimize_tf("stridedslice_test.pb", true);
     EXPECT(p == prog);
 }
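With stridedslice kept in NHWC, the length-5 extent lands on the trailing channel axis (ends {1, 1, 1, 5}) and the squeeze axis shifts accordingly, while the start/end/stride literals become inputs nothing reads, so the dead-code pass in optimize_tf is expected to drop them from the expected program. For reference, the two layout permutations that recur throughout these tests (illustrative constants, not a MIGraphX API):

#include <array>
#include <cstdint>

// perm[i] answers: which source axis provides output axis i?
constexpr std::array<int64_t, 4> nhwc_to_nchw{0, 3, 1, 2}; // transpose perm NHWC -> NCHW
constexpr std::array<int64_t, 4> nchw_to_nhwc{0, 2, 3, 1}; // transpose perm NCHW -> NHWC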