Commit e7d26442 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'bert_operators' into test_bert

parents 729f577f 026365a6
......@@ -7,6 +7,14 @@
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -17,6 +25,7 @@ struct loader
std::string file_type;
bool is_nhwc = true;
unsigned trim = 0;
bool optimize = false;
void parse(argument_parser& ap)
{
......@@ -26,6 +35,7 @@ struct loader
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
ap(optimize, {"--optimize"}, ap.help("Optimize when reading"), ap.set_value(true));
}
program load()
......@@ -48,6 +58,20 @@ struct loader
auto last = std::prev(p.end(), trim);
p.remove_instructions(last, p.end());
}
if(optimize)
migraphx::run_passes(p,
{
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{},
migraphx::simplify_algebra{},
migraphx::dead_code_elimination{},
migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::propagate_constant{},
migraphx::dead_code_elimination{},
migraphx::eliminate_pad{},
migraphx::dead_code_elimination{},
});
return p;
}
};
......
......@@ -190,6 +190,23 @@ auto pop_back_args(Ts&&... xs)
};
}
template <class T>
struct always_f
{
T x;
template <class... Ts>
constexpr T operator()(Ts&&...) const
{
return x;
}
};
template <class T>
auto always(T x)
{
return always_f<T>{x};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -8,6 +8,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -20,6 +21,12 @@ struct matcher_context
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template <class M>
bool matched(M m, instruction_ref ins)
{
return m.match(*this, ins) != this->not_found();
}
private:
instruction_ref last;
};
......@@ -205,12 +212,10 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return result;
}
/// Find matches in a program
/// Find matches for an instruction in the program
template <class... Ms>
void find_matches(program& p, Ms&&... ms)
void find_matches(program& p, instruction_ref ins, Ms&&... ms)
{
for(auto ins : iterator_for(p))
{
bool match = false;
each_args(
[&](auto&& m) {
......@@ -223,56 +228,131 @@ void find_matches(program& p, Ms&&... ms)
match = true;
},
ms...);
}
/// Find matches in a program
template <class... Ms>
void find_matches(program& p, Ms&&... ms)
{
for(auto ins : iterator_for(p))
{
find_matches(p, ins, ms...);
}
}
template <class... Ts>
auto all_of(Ts... ms)
struct lazy_and
{
template <class F, class G>
bool operator()(F f, G g) const
{
return f() and g();
}
};
struct lazy_or
{
template <class F, class G>
bool operator()(F f, G g) const
{
return f() or g();
}
};
template <class Op, bool Start, bool Matches>
struct match_fold_f
{
template <class... Ms>
static bool fold_matchers(matcher_context& ctx, instruction_ref ins, Ms... ms)
{
Op op;
auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...);
}
template <class Pack>
static bool fold_matchers_pack(matcher_context& ctx, instruction_ref ins, Pack p)
{
return p([&](auto... ms) { return match_fold_f::fold_matchers(ctx, ins, ms...); });
}
template <class... Ts>
auto operator()(Ts... ms) const
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, ins) != ctx.not_found();
})(true, ms...);
if(matches)
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches)
return ins;
return ctx.not_found();
});
}
}
template <class... Ts>
auto none_of(Ts... ms)
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, ins) == ctx.not_found();
})(true, ms...);
if(matches)
return ins;
template <class Selector>
auto operator[](Selector select) const
{
return [=](auto... ms) {
// Workaround ICE on gcc by packing matchers into an object
auto mpack = pack(ms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm);
});
if(matches == Matches)
return start;
return ctx.not_found();
});
};
}
};
const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
const constexpr auto any_of = match_fold_f<lazy_or, false, true>{};
const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
inline auto inputs()
{
return [](auto ins, auto f) {
for(auto&& x : ins->inputs())
f(x);
};
}
template <class... Ts>
auto any_of(Ts... ms)
inline auto outputs()
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x or y.match(ctx, ins) != ctx.not_found();
})(false, ms...);
if(matches)
return ins;
return ctx.not_found();
});
return [](auto ins, auto f) {
for(auto&& x : ins->outputs())
f(x);
};
}
MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
return not ins->get_shape().standard();
}
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
}
MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
{
return ins->get_shape().transposed();
}
MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
{
if(ins->inputs().empty())
return false;
auto s = ins->inputs().front()->get_shape();
return std::all_of(
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().size() == 1)
......@@ -289,10 +369,39 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return ctx.not_found();
}
inline auto name(std::string name)
template <class... Ms>
auto skip_output(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) {
if(ins->outputs().size() == 1)
{
auto next = ins->outputs().front();
if(ctx.matched(m, next))
{
auto skipped_next = self(next);
if(skipped_next != ctx.not_found())
return skipped_next;
}
return next;
}
return ctx.not_found();
})(start);
});
}
inline auto name(std::string s)
{
return make_basic_pred_matcher(
[ =, name = std::move(name) ](instruction_ref ins) { return ins->name() == name; });
[ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
}
inline auto name(std::unordered_set<std::string> names)
{
return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
return names.count(ins->name()) > 0;
});
}
inline auto nargs(std::size_t n)
......@@ -338,6 +447,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
};
}
template <class M>
auto same_shape(M m)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
auto i = m.match(ctx, ins);
if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
return ins;
return ctx.not_found();
});
}
template <class... Ms>
auto same_shape(Ms... ms)
{
return all_of(same_shape(ms)...);
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -2,14 +2,17 @@
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool is_reshaper(instruction_ref ins)
const auto& reshaper_names()
{
// clang-format off
static const std::unordered_set<std::string> names = {
......@@ -19,17 +22,10 @@ bool is_reshaper(instruction_ref ins)
"unsqueeze"
};
// clang-format on
return contains(names, ins->name());
return names;
}
bool is_transpose_output(instruction_ref ins)
{
if(ins->outputs().size() != 1)
return false;
if(ins->outputs().front()->name() == "contiguous")
return is_transpose_output(ins->outputs().front());
return ins->outputs().front()->name() == "transpose";
}
bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); }
instruction_ref find_transpose_input(instruction_ref ins)
{
......@@ -42,21 +38,62 @@ instruction_ref find_transpose_input(instruction_ref ins)
return ins;
}
void simplify_reshapes::apply(program& p) const
auto get_transpose_dims(instruction_ref ins)
{
auto end = std::prev(p.end());
for(auto ins : iterator_for(p))
return any_cast<const op::transpose&>(ins->get_operator()).dims;
}
std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t> permutation)
{
std::vector<int64_t> result(dims.size());
assert(dims.size() == permutation.size());
for(std::size_t i = 0; i < dims.size(); i++)
{
if(ins == end and ins->name() == "contiguous")
continue;
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end)
continue;
if(is_reshaper(ins))
result[i] = dims[permutation[i]];
}
return result;
}
bool is_no_transpose(const std::vector<int64_t>& dims)
{
if(dims.empty())
return true;
if(dims.front() != 0)
return false;
return std::adjacent_find(
dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end();
}
template <class Vector, class Op>
std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
std::vector<std::int64_t> result(data.size());
std::iota(result.begin(), result.end(), 0);
std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
return result;
}
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return sort_permutation(permutation, std::less<>{});
}
std::vector<int64_t> find_permutation(const shape& s)
{
return sort_permutation(s.strides(), std::greater<>{});
}
struct find_reshaper
{
auto matcher() const
{
if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
continue;
// Gather reshapes
return match::name(reshaper_names())(
match::any_of[match::outputs()](match::name(reshaper_names())));
}
void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
......@@ -83,21 +120,107 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(r.first, r.second);
}
}
else if(ins->name() == "transpose")
};
struct find_nop_reshapes
{
auto matcher() const
{
if(is_transpose_output(ins))
continue;
auto reshapes = reshaper_names();
reshapes.insert("transpose");
reshapes.insert("slice");
return match::name(reshapes)(match::same_shape(match::arg(0)));
}
void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front());
}
};
struct find_transpose
{
auto matcher() const
{
return match::name("transpose")(match::none_of(
match::skip_output(match::name("contiguous"))(match::name("transpose"))));
}
void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto x = ins;
auto t = ins;
std::vector<std::int64_t> dims(ins->get_shape().lens().size());
std::iota(dims.begin(), dims.end(), 0);
do
{
dims = reorder_dims(get_transpose_dims(t), dims);
x = t;
t = find_transpose_input(x);
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
continue;
return;
if(is_no_transpose(dims))
{
p.replace_instruction(ins, t->inputs().front());
}
else
{
p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
}
}
};
struct find_concat_transpose
{
auto matcher() const
{
return match::name("concat")(match::same_input_shapes(),
match::all_of[match::inputs()](match::transpose_shape()));
}
void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto s = ins->inputs().front()->get_shape();
assert(s.transposed());
auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s);
auto ipermutation = invert_permutation(permutation);
op.axis = ipermutation[op.axis];
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
if(i->name() == "transpose" and i->inputs().front()->get_shape().standard())
return i->inputs().front();
return p.insert_instruction(ins, op::transpose{permutation}, i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t);
}
};
void simplify_reshapes::apply(program& p) const
{
auto end = std::prev(p.end());
for(auto ins : iterator_for(p))
{
if(ins == end and ins->name() == "contiguous")
continue;
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end)
continue;
match::find_matches(p,
ins,
find_nop_reshapes{},
find_reshaper{},
find_transpose{},
find_concat_transpose{});
}
}
......
......@@ -20,10 +20,12 @@ argument concat(hipStream_t stream,
auto&& arg = args[j];
std::size_t nelements = arg.get_shape().elements();
auto offset = offsets[j];
hip_visit_all(args.back(), arg)([&](auto output, auto input) {
shape arg_shape{arg.get_shape().type(), arg.get_shape().lens()};
hip_visit_all(args.back(), arg, arg_shape)([&](auto output, auto input, auto input_shape) {
gs_launch(stream, nelements)([=](auto i) {
auto idx = output.get_shape().index(input.get_shape().multi(i));
output.data()[idx + offset] = input.data()[i];
auto input_idx = input_shape.multi(i);
auto idx = output.get_shape().index(input_idx);
output.data()[idx + offset] = input[input_idx];
});
});
}
......
......@@ -200,12 +200,33 @@ struct hip_add_relu
}
};
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
// Ensure the last arguments is the broadcasted one
auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
if(it != args.end())
std::swap(*it, *std::prev(args.end(), 2));
}
void move_standard_front(std::vector<instruction_ref>& args)
{
// Ensure the first arguments is the standard one
auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
if(it != args.end())
std::swap(*it, args.front());
}
struct find_add_relu
{
auto matcher() const
{
return match::name("gpu::relu")(match::arg(0)(
match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
return match::name("gpu::relu")(
match::arg(0)(match::any_of(match::name("gpu::add"),
match::name("hip::triadd"),
match::any_of[match::inputs()](match::standard_shape()))
.bind("add")));
}
void apply(program& p, match::matcher_result r) const
......@@ -213,6 +234,9 @@ struct find_add_relu
auto add_ins = r.instructions["add"];
auto ins = r.result;
auto args = add_ins->inputs();
move_standard_front(args);
move_broadcasted_back(args);
// Use the allocation from the relu operator
args.back() = ins->inputs().back();
if(add_ins->name() == "gpu::add")
......@@ -226,8 +250,9 @@ struct find_triadd
{
auto matcher() const
{
return match::name("gpu::add")(match::either_arg(0, 1)(match::name("gpu::add").bind("add"),
match::any().bind("input")));
return match::name("gpu::add")(match::either_arg(0, 1)(
match::name("gpu::add").bind("add"),
match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
}
void apply(program& p, match::matcher_result r) const
......@@ -236,14 +261,15 @@ struct find_triadd
auto input_ins = r.instructions["input"];
auto ins = r.result;
auto args = add_ins->inputs();
assert(add_ins != input_ins);
auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
return;
args.insert(args.begin(), input_ins);
// Ensure the last arguments is the broadcasted one
auto it = std::find_if(args.begin(), args.end(), is_broadcasted);
if(it != args.end())
std::swap(*it, *std::prev(args.end(), 2));
move_standard_front(args);
move_broadcasted_back(args);
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_triadd{}, args);
}
......@@ -402,8 +428,8 @@ void fuse_ops::apply(program& p) const
// clang-format off
match::find_matches(p, find_triadd{});
match::find_matches(p,
// find_conv_bias_relu{ctx},
// find_conv_bias{ctx},
find_conv_bias_relu{ctx},
find_conv_bias{ctx},
find_add_relu{}
);
// clang-format on
......
......@@ -36,6 +36,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
// clang-format off
return
{
dead_code_elimination{},
simplify_reshapes{},
dead_code_elimination{},
eliminate_identity{},
eliminate_pad{},
......@@ -48,11 +50,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
//dead_code_elimination{},
simplify_algebra{},
dead_code_elimination{},
propagate_constant{},
dead_code_elimination{},
auto_contiguous{},
simplify_reshapes{},
dead_code_elimination{},
propagate_constant{},
dead_code_elimination{},
lowering{ctx},
eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{},
......
......@@ -37,6 +37,48 @@ struct tf_parser
std::unordered_map<std::string, op_func> ops;
bool should_transpose(instruction_ref ins) const
{
return is_nhwc and ins->get_shape().lens().size() == 4;
}
instruction_ref to_nhwc(instruction_ref ins)
{
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
return ins;
}
instruction_ref to_nchw(instruction_ref ins)
{
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
return ins;
}
instruction_ref to_kcxy(instruction_ref ins)
{
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
return ins;
}
instruction_ref make_contiguous(instruction_ref ins)
{
if(ins->get_shape().standard())
return ins;
else
return prog.add_instruction(op::contiguous{}, ins);
}
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
{
std::vector<instruction_ref> result(args.size());
std::transform(
args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
return result;
}
std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
{
auto attrs = attributes.at(s).list().i();
......@@ -119,59 +161,67 @@ struct tf_parser
add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("ConcatV2", &tf_parser::parse_concat);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul);
add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pack", &tf_parser::parse_pack);
add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
}
template <class F>
void add_op(std::string name, F f)
void add_op(std::string name, F f, bool transpose = true)
{
ops.emplace(name, f);
if(transpose)
{
ops.emplace(name,
op_func{[=](const attribute_map& attributes,
const std::vector<instruction_ref>& args) -> instruction_ref {
return to_nhwc(f(attributes, to_nchw(args)));
}});
}
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
else
{
ops.emplace(name, f);
}
}
template <class F>
void add_mem_op(std::string name, F f)
void add_mem_op(std::string name, F f, bool transpose = true)
{
add_op(name, [=](auto&&... xs) {
add_op(name,
[=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
},
transpose);
}
template <class T>
void add_binary_op(std::string name, T x)
{
add_op(name, [this, x](const attribute_map& attributes, std::vector<instruction_ref> args) {
add_op(name,
[this, x](const attribute_map&, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
auto l0 = args[1];
if(contains(attributes, "data_format"))
{
if(is_nhwc)
{
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
}
}
return add_broadcastable_binary_op(args[0], l0, x);
});
// TODO
// if(contains(attributes, "data_format"))
// {
// if(is_nhwc)
// {
// l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
// }
// }
return add_broadcastable_binary_op(args[0], args[1], x);
},
false);
}
template <class T>
......@@ -210,20 +260,22 @@ struct tf_parser
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
return prog.add_instruction(x, l0, l1);
return to_nhwc(prog.add_instruction(x, to_nchw(l0), to_nchw(l1)));
}
else
{
return prog.add_instruction(x, {arg0, arg1});
return to_nhwc(prog.add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
}
}
template <class T>
void add_generic_op(std::string name, T x)
void add_generic_op(std::string name, T x, bool transpose = true)
{
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
add_op(name,
[this, x](const attribute_map&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
},
transpose);
}
instruction_ref
......@@ -253,7 +305,7 @@ struct tf_parser
{
// get index for axis within args
size_t axis_idx = attributes.at("N").i();
size_t axis = parse_axis(args[axis_idx]->eval().at<int64_t>());
size_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis};
// return only first N arguments (assuming last index is the axis value)
return prog.add_instruction(
......@@ -265,15 +317,7 @@ struct tf_parser
const std::vector<instruction_ref>&)
{
literal v = parse_tensor(attributes.at("value").tensor());
auto l0 = prog.add_literal(v);
size_t num_axes = l0->get_shape().lens().size();
if(num_axes >= 4)
{
std::vector<int64_t> transpose_axes = get_axes(num_axes);
reorder_data(transpose_axes);
l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
}
return l0;
return prog.add_literal(v);
}
instruction_ref
......@@ -304,21 +348,8 @@ struct tf_parser
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
auto weights = args[1];
// check if weights are from a constant
if(weights->name() != "@param")
{
if(is_nhwc)
{
weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
}
else
{
weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
}
}
auto weights = to_kcxy(args[1]);
auto l0 = args[0];
if(contains(attributes, "padding"))
{
......@@ -368,8 +399,7 @@ struct tf_parser
op.padding[1] = padding[1];
}
}
return prog.add_instruction(op, {l0, weights});
return prog.add_instruction(op, {l0, to_kcxy(args[1])});
}
instruction_ref parse_depthwiseconv(const std::string&,
......@@ -392,6 +422,8 @@ struct tf_parser
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
auto weights = to_kcxy(args[1]);
if(contains(attributes, "dilations"))
{
std::vector<size_t> dilation;
......@@ -405,20 +437,6 @@ struct tf_parser
op.dilation[1] = dilation[3];
}
auto weights = args[1];
// check if weights are from a constant
if(weights->name() != "@param")
{
if(is_nhwc)
{
weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
}
else
{
weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
}
}
auto l0 = args[0];
if(contains(attributes, "padding"))
{
......@@ -466,8 +484,8 @@ struct tf_parser
new_weights_shape[0] = out_channels;
new_weights_shape[1] = 1;
// Make sure weights are contiguous before doing reshape
auto cweights = prog.add_instruction(op::contiguous{}, weights);
auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, cweights);
auto new_weights =
prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
return prog.add_instruction(op, {l0, new_weights});
}
......@@ -535,16 +553,14 @@ struct tf_parser
MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
" must be smaller than input size " + to_string(input_size));
}
// check if input arg needs axis to be converted to NCHW
if(input_size >= 4)
axis = parse_axis(axis);
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
return to_nhwc(
prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
}
instruction_ref
......@@ -647,7 +663,7 @@ struct tf_parser
MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
auto s = args[1]->eval();
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, args[0]);
return prog.add_instruction(op, make_contiguous(args[0]));
}
void parse_from(std::istream& is)
......@@ -678,7 +694,7 @@ struct tf_parser
std::vector<instruction_ref> args)
{
op::squeeze op;
auto axes = parse_axes(attributes, "squeeze_dims");
auto axes = attributes.at("squeeze_dims").list().i();
copy(axes, std::back_inserter(op.axes));
auto args0_dims = args[0]->get_shape().lens();
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
......@@ -691,7 +707,7 @@ struct tf_parser
}
}
}
return prog.add_instruction(op, args[0]);
return prog.add_instruction(op, make_contiguous(args[0]));
}
instruction_ref parse_stridedslice(const std::string&,
......@@ -702,11 +718,6 @@ struct tf_parser
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector();
size_t num_axes = args[0]->get_shape().lens().size();
if(num_axes >= 4)
{
reorder_data(starts);
reorder_data(ends);
}
op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(ends.begin(), ends.end());
......@@ -725,13 +736,9 @@ struct tf_parser
if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
squeeze_axes.push_back(i);
}
if(num_axes >= 4)
{
squeeze_axes = parse_axes(squeeze_axes);
}
auto l0 = prog.add_instruction(op, args[0]);
return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
auto l0 = prog.add_instruction(op, make_contiguous(args[0]));
return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
}
void parse_graph(const tensorflow::GraphDef& graph)
......@@ -748,7 +755,7 @@ struct tf_parser
reorder_data(dims);
}
shape s = shape{shape_type, dims};
instructions[name] = prog.add_parameter(name, s);
instructions[name] = to_nhwc(prog.add_parameter(name, s));
}
for(auto&& p : nodes)
{
......@@ -1098,6 +1105,7 @@ program parse_tf(const std::string& name, bool is_nhwc)
#else
parser.parse_from(input);
#endif
parser.to_nchw(std::prev(parser.prog.end()));
return std::move(parser.prog);
}
......
......@@ -148,6 +148,56 @@ TEST_CASE(match_arg7)
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg8)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))),
match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::nargs(2)));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_args1)
{
migraphx::program p;
......@@ -307,6 +357,19 @@ TEST_CASE(match_all_of2)
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_all_of3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::all_of(
match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_any_of1)
{
migraphx::program p;
......@@ -359,6 +422,132 @@ TEST_CASE(match_none_of2)
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::output(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_skip_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass = p.add_instruction(pass_op{}, minus);
auto sum = p.add_instruction(sum_op{}, minus_pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass1 = p.add_instruction(pass_op{}, minus);
auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2);
auto sum = p.add_instruction(sum_op{}, minus_pass3, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum = p.add_instruction(sum_op{}, pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == two});
}
TEST_CASE(match_skip_output5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum1 = p.add_instruction(sum_op{}, pass, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_skip_output6)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum1 = p.add_instruction(sum_op{}, minus, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output7)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus1 = p.add_instruction(minus_op{}, two, one);
auto minus2 = p.add_instruction(minus_op{}, two, minus1);
auto sum = p.add_instruction(sum_op{}, one, minus2);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus1});
}
TEST_CASE(match_bind1)
{
migraphx::program p;
......
......@@ -38,7 +38,7 @@ TEST_CASE(test_shape_packed)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_transposed)
TEST_CASE(test_shape_transposed1)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.standard());
......@@ -47,6 +47,15 @@ TEST_CASE(test_shape_transposed)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_transposed2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1, 1, 2}, {2, 2, 2, 2, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_broadcasted)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
......
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -165,4 +166,144 @@ TEST_CASE(transpose_double_contiguous)
EXPECT(p.has_instruction(t));
}
TEST_CASE(transpose_partial1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(transpose_partial2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_instruction(pass_op{}, t3);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
}
TEST_CASE(transpose_partial3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}
TEST_CASE(nop_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(nop_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 4);
}
TEST_CASE(nop_transpose3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto concat = p.add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(concat_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}
TEST_CASE(concat_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp>
#include "test.hpp"
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
auto prog = migraphx::parse_tf(name, is_nhwc);
if(is_nhwc)
migraphx::run_passes(prog,
{migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::eliminate_identity{}});
return prog;
}
TEST_CASE(add_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::add{}, l0, l1);
auto prog = migraphx::parse_tf("add_test.pb", false);
auto prog = optimize_tf("add_test.pb", false);
EXPECT(p == prog);
}
......@@ -28,7 +43,7 @@ TEST_CASE(add_bcast_test)
auto l2 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3);
auto prog = migraphx::parse_tf("add_bcast_test.pb", false);
auto prog = optimize_tf("add_bcast_test.pb", false);
EXPECT(p == prog);
}
......@@ -51,7 +66,7 @@ TEST_CASE(batchnorm_test)
auto l4 = p.add_parameter("4", s0);
auto l1 = p.add_literal(migraphx::literal{s0, const_vals});
p.add_instruction(op, l0, l1, l2, l3, l4);
auto prog = migraphx::parse_tf("batchnorm_test.pb", true);
auto prog = optimize_tf("batchnorm_test.pb", true);
EXPECT(p == prog);
}
......@@ -65,7 +80,7 @@ TEST_CASE(biasadd_test)
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_tf("biasadd_test.pb", true);
auto prog = optimize_tf("biasadd_test.pb", true);
EXPECT(p == prog);
}
......@@ -83,7 +98,7 @@ TEST_CASE(concat_test)
p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
auto prog = migraphx::parse_tf("concat_test.pb", false);
auto prog = optimize_tf("concat_test.pb", false);
EXPECT(p == prog);
}
......@@ -92,7 +107,7 @@ TEST_CASE(const_test)
{
migraphx::program p;
p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
auto prog = migraphx::parse_tf("constant_test.pb", false);
auto prog = optimize_tf("constant_test.pb", false);
EXPECT(p == prog);
}
......@@ -112,10 +127,9 @@ TEST_CASE(conv_test)
op.padding = {1, 1};
op.stride = {1, 1};
op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
p.add_instruction(op, l0, l3);
auto prog = migraphx::parse_tf("conv_test.pb", true);
auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
p.add_instruction(op, l0, l2);
auto prog = optimize_tf("conv_test.pb", true);
EXPECT(p == prog);
}
......@@ -136,12 +150,11 @@ TEST_CASE(depthwiseconv_test)
op.stride = {1, 1};
op.dilation = {1, 1};
op.group = 3;
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3);
auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
p.add_instruction(op, l0, l5);
auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true);
auto prog = optimize_tf("depthwise_conv_test.pb", true);
EXPECT(p == prog);
}
......@@ -151,7 +164,7 @@ TEST_CASE(identity_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0);
auto prog = migraphx::parse_tf("identity_test.pb", false);
auto prog = optimize_tf("identity_test.pb", false);
EXPECT(p == prog);
}
......@@ -166,7 +179,7 @@ TEST_CASE(matmul_test)
auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
auto prog = migraphx::parse_tf("matmul_test.pb", false);
auto prog = optimize_tf("matmul_test.pb", false);
EXPECT(p == prog);
}
......@@ -183,7 +196,7 @@ TEST_CASE(mean_test)
p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = migraphx::parse_tf("mean_test.pb", false);
auto prog = optimize_tf("mean_test.pb", false);
EXPECT(p == prog);
}
......@@ -193,14 +206,11 @@ TEST_CASE(mean_test_nhwc)
migraphx::program p;
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_literal(l);
p.add_literal(l);
migraphx::op::pooling op;
op.lengths = {16, 16};
p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = migraphx::parse_tf("mean_test_nhwc.pb", true);
auto prog = optimize_tf("mean_test_nhwc.pb", true);
EXPECT(p == prog);
}
......@@ -212,7 +222,7 @@ TEST_CASE(mul_test)
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::mul{}, l0, l1);
auto prog = migraphx::parse_tf("mul_test.pb", false);
auto prog = optimize_tf("mul_test.pb", false);
EXPECT(p == prog);
}
......@@ -234,7 +244,7 @@ TEST_CASE(pack_test)
return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test.pb", false);
auto prog = optimize_tf("pack_test.pb", false);
EXPECT(p == prog);
}
......@@ -243,11 +253,14 @@ TEST_CASE(pack_test_nhwc)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
std::vector<migraphx::instruction_ref> args{l0, l1, l2};
auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 1;
int64_t nchw_axis = 3;
std::transform(args.begin(),
args.end(),
......@@ -256,7 +269,7 @@ TEST_CASE(pack_test_nhwc)
return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test_nhwc.pb", true);
auto prog = optimize_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog);
}
......@@ -273,9 +286,9 @@ TEST_CASE(pooling_test)
max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2};
max_pool_op.lengths = {2, 2};
p.add_instruction(avg_pool_op, l0);
p.add_instruction(max_pool_op, l0);
auto prog = migraphx::parse_tf("pooling_test.pb", true);
// p.add_instruction(avg_pool_op, l0);
auto prog = optimize_tf("pooling_test.pb", true);
EXPECT(p == prog);
}
......@@ -285,7 +298,7 @@ TEST_CASE(relu_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::relu{}, l0);
auto prog = migraphx::parse_tf("relu_test.pb", false);
auto prog = optimize_tf("relu_test.pb", false);
EXPECT(p == prog);
}
......@@ -295,7 +308,7 @@ TEST_CASE(relu6_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
auto prog = migraphx::parse_tf("relu6_test.pb", false);
auto prog = optimize_tf("relu6_test.pb", false);
EXPECT(p == prog);
}
......@@ -308,7 +321,7 @@ TEST_CASE(reshape_test)
// in tf, the second arg is a literal that contains new dimensions
p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
auto prog = migraphx::parse_tf("reshape_test.pb", false);
auto prog = optimize_tf("reshape_test.pb", false);
EXPECT(p == prog);
}
......@@ -321,7 +334,7 @@ TEST_CASE(softmax_test)
auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r);
p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s);
auto prog = migraphx::parse_tf("softmax_test.pb", false);
auto prog = optimize_tf("softmax_test.pb", false);
EXPECT(p == prog);
}
......@@ -331,7 +344,7 @@ TEST_CASE(squeeze_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
auto prog = migraphx::parse_tf("squeeze_test.pb", false);
auto prog = optimize_tf("squeeze_test.pb", false);
EXPECT(p == prog);
}
......@@ -343,18 +356,13 @@ TEST_CASE(stridedslice_test)
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 0, 0};
op.ends = {1, 5, 1, 1};
op.ends = {1, 1, 1, 5};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 5});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
auto l1 = p.add_instruction(op, l0);
auto shrink_axis = 2;
auto shrink_axis = 1;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
auto prog = migraphx::parse_tf("stridedslice_test.pb", true);
auto prog = optimize_tf("stridedslice_test.pb", true);
EXPECT(p == prog);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment