Unverified commit 2fdf510d authored by mvermeulen, committed by GitHub

Merge branch 'develop' into onnx_autopad_fix

parents d5939189 4085af9b
@@ -12,7 +12,7 @@ add_library(migraphx
     eliminate_concat.cpp
     eliminate_identity.cpp
     eliminate_pad.cpp
-    fwd_conv_batchnorm_rewrite.cpp
+    rewrite_batchnorm.cpp
     rewrite_rnn.cpp
     rewrite_pooling.cpp
     env.cpp
......
@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
        [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
            auto result = m.match(ctx, ins);
            if(result != ctx.not_found())
-               ctx.instructions.emplace(name, ins);
+               ctx.instructions[name] = ins;
            return result;
        });
}
@@ -240,6 +240,21 @@ void find_matches(program& p, Ms&&... ms)
    }
}
template <class M>
struct find_skip
{
M m;
M matcher() const { return m; }
void apply(program&, const matcher_result&) const {}
};
template <class M>
find_skip<M> make_find_skip(M m)
{
return {m};
}
struct lazy_and
{
    template <class F, class G>
@@ -311,6 +326,12 @@ const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
const constexpr auto any_of  = match_fold_f<lazy_or, false, true>{};
const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
template <class... Ms>
auto skip_matches(Ms... ms)
{
return make_find_skip(any_of(ms...));
}
inline auto inputs()
{
    return [](auto ins, auto f) {
@@ -369,6 +390,50 @@ MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref in
    return ctx.not_found();
}
inline auto used_once_recursive(std::size_t depth)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
// Used once
if(start->outputs().size() == 1)
return start;
// Unused
if(start->outputs().empty())
{
if(std::next(start) == ctx.not_found())
return start;
else
return ctx.not_found();
}
// Check for dead instructions
auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
if(n == 0)
return false;
if(ins->get_shape().elements() == 0)
return false;
if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
return true;
return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return self(i, n - 1);
});
});
auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
return is_dead(i, depth);
});
if(dead + 1 == start->outputs().size())
return start;
return ctx.not_found();
});
}
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
return ins;
return ctx.not_found();
}
template <class... Ms>
auto skip_output(Ms... ms)
{
@@ -404,6 +469,12 @@ inline auto name(std::unordered_set<std::string> names)
    });
}
template <class... Ts>
inline auto name(std::string s, Ts... xs) // NOLINT
{
return name(std::unordered_set<std::string>{std::move(s), std::move(xs)...});
}
inline auto nargs(std::size_t n)
{
    return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
......
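Note: the fold-based all_of/any_of/none_of combinators above evaluate their sub-matchers lazily, which the new match_lazy_* test cases further down depend on. A standalone C++17 sketch of the short-circuiting idea (illustrative only, not the MIGraphX matcher API):

#include <cassert>
#include <cstdio>

// Build a predicate that is true when any of the given predicates is true.
// The fold over || evaluates left to right and stops at the first success,
// so later predicates are never called once the result is decided.
template <class... Preds>
auto lazy_any_of(Preds... preds)
{
    return [=](int x) { return (preds(x) || ...); };
}

int main()
{
    auto is_positive = [](int x) { return x > 0; };
    auto never_run   = [](int) -> bool { assert(false); return false; };

    // never_run is skipped because is_positive already matched.
    std::printf("%d\n", lazy_any_of(is_positive, never_run)(1)); // prints 1
}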
@@ -13,9 +13,9 @@ struct program;
/**
 * Rewrite batchnorm to a multiply and add.
 */
-struct fwd_conv_batchnorm_rewrite
+struct rewrite_batchnorm
{
-    std::string name() const { return "fwd_conv_batchnorm_rewrite"; }
+    std::string name() const { return "rewrite_batchnorm"; }
    void apply(program& p) const;
};
......
@@ -206,6 +206,16 @@ struct onnx_parser
        return out_lens;
    }
instruction_ref make_contiguous(instruction_ref ins)
{
if(ins->get_shape().standard())
{
return ins;
}
return prog.add_instruction(op::contiguous{}, ins);
}
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
@@ -441,12 +451,7 @@ struct onnx_parser
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

-       if(!args[0]->get_shape().standard())
-       {
-           args[0] = prog.add_instruction(op::contiguous{}, args[0]);
-       }
-
-       return prog.add_instruction(op, args[0]);
+       return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
@@ -494,23 +499,41 @@ struct onnx_parser
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
-       return prog.add_instruction(op, std::move(args));
+       return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

    instruction_ref
    parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::slice op;
std::vector<size_t> dims = args[0]->get_shape().lens();
size_t num_dims = dims.size();
if(contains(attributes, "axes")) if(contains(attributes, "axes"))
{ {
literal s = parse_value(attributes.at("axes")); literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
} }
else
{
op.axes = std::vector<int64_t>(num_dims);
std::iota(op.axes.begin(), op.axes.end(), 0);
}
if(contains(attributes, "ends"))
        {
            literal s = parse_value(attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
for(size_t i = 0; i < num_dims; i++)
{
if(static_cast<size_t>(op.ends[i]) > dims[i])
{
op.ends[i] = dims[i];
}
}
        }

        if(contains(attributes, "starts"))
        {
            literal s = parse_value(attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
......
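Note: the new ends-clamping in parse_slice exists because ONNX exporters commonly encode "slice to the end of the axis" as a huge end index such as INT64_MAX. A standalone sketch of the clamp under that assumption (illustrative values, not parser code):

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<std::size_t> dims  = {3, 224, 224}; // input tensor lens
    std::vector<std::int64_t> ends = {INT64_MAX, 100, INT64_MAX};

    // Any end index past the axis extent is clipped to the extent itself.
    for(std::size_t i = 0; i < dims.size(); i++)
        if(static_cast<std::size_t>(ends[i]) > dims[i])
            ends[i] = static_cast<std::int64_t>(dims[i]);

    for(auto e : ends)
        std::printf("%lld ", static_cast<long long>(e)); // 3 100 224
}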
@@ -74,7 +74,8 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
            // if the input is a convert operator, uses its input
            // as its current input
            instruction_ref input_fp16{};
-           if(input->name() == "convert")
+           if(input->name() == "convert" and
+              input->inputs().front()->get_shape().type() == shape::half_type)
            {
                input_fp16 = input->inputs().front();
            }
......
-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
+#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
@@ -11,7 +12,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

-void fwd_conv_batchnorm_rewrite::apply(program& p) const
+void rewrite_batchnorm::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
@@ -25,46 +26,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
        if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
            continue;
-       auto conv_ins = ins->inputs()[0];
-       if(conv_ins->name() != "convolution")
-           continue;
-       // Get convolution weights
-       auto weights = conv_ins->inputs()[1]->eval();
-       if(weights.empty())
-           continue;
+       auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}};

        // Get epsilon
        auto bn_op   = any_cast<op::batch_norm_inference>(ins->get_operator());
        auto epsilon = bn_op.epsilon;

-       // Get convolution op
-       auto conv_op      = conv_ins->get_operator();
-       auto weights_lens = weights.get_shape().lens();
-       auto conv_lens    = conv_ins->get_shape().lens();
-       argument new_weights{weights.get_shape()};
-       argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
-       visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
-           [&](auto weights2,
-               auto gamma2,
-               auto bias2,
-               auto mean2,
-               auto variance2,
-               auto new_weights2,
-               auto new_bias2) {
-               dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
-                   [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
-                       new_weights2(k, c, h, w) =
-                           gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w);
-                   });
-               dfor(new_bias.get_shape().elements())([&](std::size_t c) {
-                   new_bias2[c] =
-                       bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
-               });
-           });
-       // Replace convolution instruction with updated weights
-       auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
-       auto l_bias    = p.add_literal({new_bias.get_shape(), new_bias.data()});
-       auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
-       auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape().lens()}, l_bias);
-       p.replace_instruction(ins, op::add{}, {c, b});
+       argument a{s};
+       argument b{s};
+       visit_all(gamma, bias, mean, variance, a, b)(
+           [&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
+               dfor(a.get_shape().elements())(
+                   [&](std::size_t c) { a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon); });
+               dfor(b.get_shape().elements())([&](std::size_t c) {
+                   b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
+               });
+           });

+       auto broadcast   = op::broadcast{1, ins->get_shape().lens()};
+       auto a_ins       = p.add_literal({a.get_shape(), a.data()});
+       auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
+       auto mul         = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
+       auto b_ins       = p.add_literal({b.get_shape(), b.data()});
+       auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
+       auto add         = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
+       p.replace_instruction(ins, add);
    }
}
......
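Note: the pass now folds any inference-mode batchnorm into one per-channel multiply and add; the a and b literals computed above are exactly

    y = gamma * (x - mean) / sqrt(variance + epsilon) + bias
      = a * x + b,
    where a = gamma / sqrt(variance + epsilon)
          b = bias - gamma * mean / sqrt(variance + epsilon)

Dropping the old convolution-only weight folding widens the rewrite to batchnorms with any input; when the input happens to be a convolution with constant weights, the new find_mul_conv matcher in simplify_algebra folds the multiply back into the weights.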
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
-struct find_add_lit_broadcast
-{
-    auto lit_broadcast() const
-    {
-        return match::any_of(match::name("@literal"), match::name("broadcast"));
-    }
-    auto not_lit_broadcast() const
-    {
-        return match::none_of(match::name("@literal"), match::name("broadcast"));
-    }
-    auto add_lit_broadcast(std::string x, std::string y) const
-    {
-        return match::name("add")(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
-                                                          not_lit_broadcast().bind(std::move(y))));
-    }
-
-    auto matcher() const
-    {
-        return match::name("add")(
-            match::args(add_lit_broadcast("a", "x"), add_lit_broadcast("b", "y")));
-    }
+auto lit_broadcast() { return match::any_of(match::is_constant(), match::name("broadcast")); }
+auto not_lit_broadcast() { return match::none_of(match::is_constant(), match::name("broadcast")); }
+
+auto op_lit_broadcast(std::string op, std::string x, std::string y)
+{
+    return match::name(std::move(op))(match::either_arg(0, 1)(
+        lit_broadcast().bind(std::move(x)), not_lit_broadcast().bind(std::move(y))));
+}
+
+auto conv_const_weights()
+{
+    return match::name("convolution")(match::used_once(),
+                                      match::args(match::any(), match::is_constant().bind("w")));
+}
+
+struct find_mul_conv
+{
+    auto matcher() const
+    {
+        return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"),
+                                                          match::name("broadcast").bind("a")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins      = r.result;
+        auto conv_ins = r.instructions["conv"];
+        auto a_ins    = r.instructions["a"];
+        auto w_ins    = r.instructions["w"];
+
+        auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
+        if(broadcast_op.axis != 1)
+            return;
+
+        auto new_a = p.insert_instruction(
+            ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
+        auto new_mul  = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
+        auto new_conv = p.insert_instruction(
+            ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
+        p.replace_instruction(ins, new_conv);
+    }
+};
+
+// a * (x + b) => a * x + a * b
+struct find_mul_add
+{
+    auto matcher() const
+    {
+        return match::name("mul")(match::either_arg(0, 1)(
+            match::name("add")(
+                match::either_arg(0, 1)(
+                    match::any().bind("x"),
+                    match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
+                match::none_of(match::args(match::is_constant(), match::is_constant())),
+                match::used_once()),
+            match::is_constant().bind("a")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+        auto x_ins = r.instructions["x"];
+        assert(x_ins != b_ins);
+
+        auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins);
+        auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins);
+        p.replace_instruction(ins, op::add{}, ax_ins, ab_ins);
+    }
+};
+
+struct find_add_lit_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("add")(
+            match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto x_ins = r.instructions["x"];
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+
+        auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
+        p.replace_instruction(ins, op::add{}, x_ins, sumab);
+    }
+};
+
+struct find_double_add_lit_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("add")(
+            match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
+    }

    void apply(program& p, match::matcher_result r) const
@@ -36,11 +117,9 @@ struct find_add_lit_broadcast
        auto a_ins = r.instructions["a"];
        auto b_ins = r.instructions["b"];

-       if(a_ins->name() != b_ins->name())
-           return;
        instruction_ref sumab;

-       if(a_ins->name() == "broadcast")
+       if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
        {
            if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
                return;
@@ -59,7 +138,46 @@
    }
};

-void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
+struct find_inner_broadcast
{
auto matcher() const
{
return match::name("mul", "add")(
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
if(xbroadcast.axis != ybroadcast.axis)
return;
auto op = p.insert_instruction(
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
p.replace_instruction(ins, xbroadcast, op);
}
};
void simplify_algebra::apply(program& p) const
{
// Run simplifications multiple times
for(int i = 0; i < 4; i++)
{
match::find_matches(p,
find_inner_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
find_mul_conv{},
find_mul_add{});
dead_code_elimination{}.apply(p);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
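Note: the rewrites above are plain distributivity and broadcast hoisting; written out with a and b constant:

    a * (x + b) = a * x + a * b          // a * b is constant, so later passes fold it to one literal
    broadcast(x) op broadcast(y) = broadcast(x op y)   // find_inner_broadcast, same axis, op is mul or add

Running find_matches four times with dead_code_elimination in between lets one rewrite expose the next, e.g. find_mul_add moves a constant multiply next to a convolution, which find_mul_conv then folds into the convolution weights.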
@@ -16,6 +16,7 @@ add_library(migraphx_device
    device/argmin.cpp
    device/max.cpp
    device/min.cpp
    device/mul_add.cpp
    device/exp.cpp
    device/erf.cpp
    device/log.cpp
......
@@ -6,6 +6,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul_add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
nary(stream, result, arg1, arg2, arg3)(
[](auto x, auto a, auto b) { return std::max<decltype(a * x + b)>(0, a * x + b); });
}
void add_relu(hipStream_t stream,
              const argument& result,
              const argument& arg1,
......
@@ -118,6 +118,111 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
    });
}
template <class F, class... Arguments>
void nary_double_broadcast_vec_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{
assert(barg1.get_shape().broadcasted());
assert(barg2.get_shape().broadcasted());
assert(barg1.get_shape() == barg2.get_shape());
const auto& output_shape = result.get_shape();
const auto& b_shape = barg1.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) {
using type = typename decltype(output)::value_type;
const std::size_t nelements = output.size() / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
buffer[i] = binput1.data()[i];
}
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
buffer[i + bdim_vec_len] = binput2.data()[i];
}
__syncthreads();
auto* bp = as_pointer(buffer);
// Process the data
for(size_t i = idx.global; i < nelements; i += nglobal)
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b1 = bp[bidx];
auto b2 = bp[bidx + bdim_len];
auto out = output.data()[i];
for(std::size_t j = 0; j < vec_size; j++)
{
out[j] = f(inputs.data()[i][j]..., b2, b1);
}
output.data()[i] = out;
}
});
});
}
template <class F, class... Arguments>
void nary_double_broadcast_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{
assert(barg1.get_shape().broadcasted());
assert(barg2.get_shape().broadcasted());
assert(barg1.get_shape() == barg2.get_shape());
const auto& output_shape = result.get_shape();
const auto& b_shape = barg1.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
std::size_t nelements = result.get_shape().elements();
hip_visit_all(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) {
using type = typename decltype(output)::value_type;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i] = binput1.data()[i];
}
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i + bdim_len] = binput2.data()[i];
}
__syncthreads();
// Process the data
for(size_t i = idx.global; i < nelements; i += nglobal)
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b1 = buffer[bidx];
auto b2 = buffer[bidx + bdim_len];
output.data()[i] = f(inputs.data()[i]..., b2, b1);
}
});
});
}
template <class F, class... Arguments>
void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
@@ -177,49 +282,113 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
}

template <class... Arguments>
-auto nary(hipStream_t stream, argument result)
+bool broadcastable(bool& divisible_by_4,
+                   std::size_t max_size,
+                   const argument& result,
+                   const argument& barg,
+                   const Arguments&... args)
+{
+    divisible_by_4 = false;
+    auto bshape    = barg.get_shape();
+    const bool standard =
+        all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
+    const bool same_shapes =
+        all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
+    // TODO: Check result and args shape is the same
+    if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
+    {
+        auto not_zero       = [](auto x) { return x != 0; };
+        const auto& strides = bshape.strides();
+        auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
+        auto b_idx          = std::distance(strides.begin(), b_it);
+        auto b_len          = result.get_shape().lens()[b_idx];
+        auto b_stride       = result.get_shape().strides()[b_idx];
+        assert(bshape.lens()[b_idx] == b_len);
+        if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero))
+        {
+            divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
+                             (front_args(args...).get_shape().elements() % 4 == 0);
+            return true;
+        }
+    }
+    return false;
+}
+
+inline bool broadcastable(bool& divisible_by_4, std::size_t, const argument&, const argument&)
+{
+    divisible_by_4 = false;
+    return false;
+}
+
+// Nullary
+inline auto nary(hipStream_t stream, argument result)
{
    return [=](auto f) { nary_standard_impl(stream, f, result); };
}

+// Unary
+inline auto nary(hipStream_t stream, argument result, argument arg)
+{
+    return [=](auto f) { nary_impl(stream, f, result, arg); };
+}
+
+// Binary
+inline auto nary(hipStream_t stream, argument result, argument arg, argument barg)
+{
+    return [=](auto f) {
+        bool divisible_by_4 = false;
+        if(broadcastable(divisible_by_4, 2048, result, barg, arg))
+        {
+            if(divisible_by_4)
+                nary_broadcast_vec_impl(stream, f, result, barg, arg);
+            else
+                nary_broadcast_impl(stream, f, result, barg, arg);
+        }
+        else
+        {
+            nary_impl(stream, f, result, arg, barg);
+        }
+    };
+}
+
template <class... Arguments>
auto nary(hipStream_t stream, argument result, Arguments... args)
{
+    static_assert(sizeof...(args) > 2, "Args needs to be greater than 2");
    return [=](auto f) {
-        auto barg     = back_args(args...);
-        bool fallback = pop_back_args(args...)([&](auto&&... args2) {
-            auto bshape = barg.get_shape();
-            const bool standard =
-                all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
-            const bool same_shapes = all_of(
-                {args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
-            // TODO: Check result and args shape is the same
-            if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
-            {
-                auto not_zero       = [](auto x) { return x != 0; };
-                const auto& strides = bshape.strides();
-                auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
-                auto b_idx          = std::distance(strides.begin(), b_it);
-                auto b_len          = result.get_shape().lens()[b_idx];
-                auto b_stride       = result.get_shape().strides()[b_idx];
-                assert(bshape.lens()[b_idx] == b_len);
-                if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
-                {
-                    const bool divisible_by_4 =
-                        (b_len % 4 == 0) and (b_stride % 4 == 0) and
-                        (front_args(args...).get_shape().elements() % 4 == 0);
-                    if(divisible_by_4)
-                        nary_broadcast_vec_impl(stream, f, result, barg, args2...);
-                    else
-                        nary_broadcast_impl(stream, f, result, barg, args2...);
-                    return false;
-                }
-            }
-            return true;
-        });
-        if(fallback)
+        auto barg1     = back_args(args...);
+        bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
+            auto barg2     = back_args(args2...);
+            bool fallback2 =
+                barg2.get_shape() != barg1.get_shape() or not barg2.get_shape().broadcasted() or
+                pop_back_args(args2...)([&](auto&&... args3) {
+                    bool divisible_by_4 = false;
+                    if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
+                    {
+                        if(divisible_by_4)
+                            nary_double_broadcast_vec_impl(
+                                stream, f, result, barg1, barg2, args3...);
+                        else
+                            nary_double_broadcast_impl(stream, f, result, barg1, barg2, args3...);
+                        return false;
+                    }
+                    return true;
+                });
+            if(not fallback2)
+                return false;
+            bool divisible_by_4 = false;
+            if(broadcastable(divisible_by_4, 2048, result, barg1, args2...))
+            {
+                if(divisible_by_4)
+                    nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
+                else
+                    nary_broadcast_impl(stream, f, result, barg1, args2...);
+                return false;
+            }
+            return true;
+        });
+        if(fallback1)
            nary_impl(stream, f, result, args...);
    };
}
......
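Note: the double-broadcast kernels above stage both per-channel vectors in LDS (binput1 at offset 0, binput2 at offset bdim_len) and then recover the broadcast-axis coordinate from the flat element index. A standalone sketch of that index arithmetic for a standard-layout output (illustrative, not device code):

#include <cstddef>
#include <cstdio>

int main()
{
    // Output shape {2, 3, 4} in row-major order, broadcasting along axis 1,
    // so the broadcast axis has len 3 and stride 4 in the output.
    const std::size_t bdim_len         = 3;
    const std::size_t bdim_stride      = 4;
    const std::size_t bdim_next_stride = bdim_len * bdim_stride;

    for(std::size_t i = 0; i < 24; i++)
    {
        // (i % (len*stride)) / stride is the coordinate along the broadcast
        // axis, i.e. the index into the cached per-channel buffer.
        auto bidx = (i % bdim_next_stride) / bdim_stride;
        std::printf("%zu ", bidx); // 0 0 0 0 1 1 1 1 2 2 2 2, twice
    }
    std::printf("\n");
}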
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul_add(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b) { return a * x + b; });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -2,6 +2,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device/mul_add.hpp>
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/instruction.hpp>
@@ -198,21 +199,62 @@ struct hip_add_relu
    }
};
struct hip_mul_add
{
std::string name() const { return "hip::mul_add"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_mul_add_relu
{
std::string name() const { return "hip::mul_add_relu"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::mul_add_relu(
ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
    // Ensure the last arguments is the broadcasted one
-   auto it = std::find_if(
-       args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
-   if(it != args.end())
-       std::swap(*it, *std::prev(args.end(), 2));
+   auto last = std::prev(args.end());
+   auto it =
+       std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
+   if(it != last)
+       std::swap(*it, *std::prev(last));
}

void move_standard_front(std::vector<instruction_ref>& args)
{
    // Ensure the first arguments is the standard one
-   auto it = std::find_if(
-       args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
-   if(it != args.end())
+   auto last = std::prev(args.end());
+   auto it =
+       std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
+   if(it != last)
        std::swap(*it, args.front());
}
@@ -220,11 +262,13 @@ struct find_add_relu
{
    auto matcher() const
    {
-       return match::name("gpu::relu")(
-           match::arg(0)(match::any_of(match::name("gpu::add"),
-                                       match::name("hip::triadd"),
-                                       match::any_of[match::inputs()](match::standard_shape()))
-                             .bind("add")));
+       return match::name("gpu::relu")(match::arg(0)(
+           match::used_once(),
+           match::any_of(match::name("gpu::add"),
+                         match::name("hip::triadd"),
+                         match::any_of(match::name("@literal"),
+                                       match::any_of[match::inputs()](match::standard_shape())))
+               .bind("add")));
    }

    void apply(program& p, match::matcher_result r) const
@@ -249,8 +293,10 @@ struct find_triadd
    auto matcher() const
    {
        return match::name("gpu::add")(match::either_arg(0, 1)(
-           match::name("gpu::add").bind("add"),
-           match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
+           match::name("gpu::add")(match::used_once()).bind("add"),
+           match::any(match::any_of(match::name("@literal"),
+                                    match::any_of[match::inputs()](match::standard_shape())))
+               .bind("input")));
    }

    void apply(program& p, match::matcher_result r) const
@@ -273,6 +319,51 @@ struct find_triadd
    }
};
struct find_mul_add
{
auto matcher() const
{
return match::name("gpu::add")(match::either_arg(0, 1)(
match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
}
void apply(program& p, match::matcher_result r) const
{
auto mul_ins = r.instructions["mul"];
auto b_ins = r.instructions["b"];
auto ins = r.result;
auto args = mul_ins->inputs();
assert(mul_ins != b_ins);
move_standard_front(args);
move_broadcasted_back(args);
args.insert(std::prev(args.end()), b_ins);
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_mul_add{}, args);
}
};
struct find_mul_add_relu
{
auto matcher() const
{
return match::name("gpu::relu")(
match::arg(0)(match::name("hip::mul_add")(match::used_once()).bind("mul_add")));
}
void apply(program& p, match::matcher_result r) const
{
auto mul_add_ins = r.instructions["mul_add"];
auto ins = r.result;
auto args = mul_add_ins->inputs();
// Use the allocation from the relu operator
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_mul_add_relu{}, args);
}
};
struct miopen_conv_bias
{
    op::convolution op;
@@ -428,6 +519,8 @@ void fuse_ops::apply(program& p) const
    match::find_matches(p,
                        find_conv_bias_relu{ctx},
                        find_conv_bias{ctx},
                        find_mul_add{},
                        find_mul_add_relu{},
                        find_add_relu{}
    );
    // clang-format on
......
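Note: hip_mul_add and hip_mul_add_relu take the output allocation as their last argument and alias it (output_alias returns the last index). Their element-wise semantics, as a standalone host-side sketch of the functors that device::mul_add and device::mul_add_relu hand to nary():

#include <algorithm>
#include <cstdio>

int main()
{
    // Same lambdas as the device code, minus the GPU launch machinery.
    auto mul_add      = [](float x, float a, float b) { return a * x + b; };
    auto mul_add_relu = [](float x, float a, float b) { return std::max(0.f, a * x + b); };

    std::printf("%g %g\n",
                mul_add(2.f, 3.f, -10.f),       // 3*2 - 10 = -4
                mul_add_relu(2.f, 3.f, -10.f)); // max(0, -4) = 0
}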
@@ -11,6 +11,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul_add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
void add_relu(hipStream_t stream,
              const argument& result,
              const argument& arg1,
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul_add(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -8,51 +8,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
-template <class... Ts>
-rocblas_status generic_rocblas_gemm_ex(Ts&&... xs)
-{
-    return rocblas_gemm_ex(std::forward<Ts>(xs)...);
-}
-
-template <class... Ts>
-rocblas_status generic_rocblas_batched_gemm_ex(Ts&&... xs)
-{
-    return rocblas_gemm_strided_batched_ex(std::forward<Ts>(xs)...);
-}
-
-template <class T>
-struct compute_rocblas_type
-{
-    using type = T;
-};
-
-template <class T>
-struct compute_rocblas_type<const T>
-{
-    using type = const typename compute_rocblas_type<T>::type;
-};
-
-template <>
-struct compute_rocblas_type<half>
-{
-    using type = rocblas_half;
-};
-
-template <class T>
-using rb_type = typename compute_rocblas_type<T>::type;
-
-template <class T>
-rb_type<T> to_rocblas_type(T x)
-{
-    return reinterpret_cast<const rb_type<T>&>(x);
-}
-
-template <class T>
-rb_type<T>* to_rocblas_type(T* x)
-{
-    return reinterpret_cast<rb_type<T>*>(x);
-}
shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    std::vector<shape> in_shapes(inputs);
@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx,
    auto a_lens = args[0].get_shape().lens();
    auto b_lens = args[1].get_shape().lens();
    output_shape.visit_type([&](auto as) {
-       auto alpha_r = to_rocblas_type(as(op.alpha));
-       auto beta_r  = to_rocblas_type(as(beta));
+       auto alpha_r  = as(op.alpha);
+       auto beta_r   = as(beta);
        auto out_lens = output_shape.lens();
        rocblas_int m = out_lens[dim_0];
        rocblas_int n = out_lens[dim_1];
        rocblas_int k = args[0].get_shape().lens()[dim_1];
-       auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
+       auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
        assert(k % 4 == 0);

        auto num_matrices = std::accumulate(
@@ -119,36 +74,36 @@ argument rocblas_quant_gemm::compute(context& ctx,
            // column-major format. When doing a C = A * B, we actually do
            // C^T = (B^T) * (A^T). That is the reason we input args[1] as
            // A and args[0] as B in calling the rocblas_gemm.
-           generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
+           rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
                            transb ? rocblas_operation_transpose : rocblas_operation_none,
                            transa ? rocblas_operation_transpose : rocblas_operation_none,
                            n,
                            m,
                            k,
                            &alpha_r,
                            to_pointer(args.at(1)),
                            rocblas_datatype_i8_r,
                            ldb,
                            to_pointer(args.at(0)),
                            rocblas_datatype_i8_r,
                            lda,
                            &beta_r,
                            to_pointer(args[2]),
                            rocblas_datatype_i32_r,
                            ldc,
                            is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                            rocblas_datatype_i32_r,
                            ldc,
                            rocblas_datatype_i32_r,
                            rocblas_gemm_algo_standard,
                            0,
                            0,
                            nullptr,
                            nullptr);
        }
        else
        {
-           generic_rocblas_batched_gemm_ex(
+           rocblas_gemm_strided_batched_ex(
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
......
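Note: the operand swap in the calls above relies on the transpose identity

    C = A * B   <=>   C^T = B^T * A^T

rocBLAS operates on column-major data, and a row-major C occupies the same bytes as a column-major C^T, so passing args[1] as A and args[0] as B (with m and n swapped) produces the row-major product without any explicit transpose. The int8 kernels additionally require k to be a multiple of 4, hence the assert(k % 4 == 0).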
@@ -14,7 +14,7 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/eliminate_concat.hpp>
@@ -44,13 +44,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
        eliminate_identity{},
        eliminate_pad{},
        dead_code_elimination{},
-       fwd_conv_batchnorm_rewrite{},
+       rewrite_batchnorm{},
        dead_code_elimination{},
        rewrite_rnn{},
        rewrite_pooling{},
        dead_code_elimination{},
-       //common_subexpression_elimination{},
-       //dead_code_elimination{},
+       // common_subexpression_elimination{},
+       // dead_code_elimination{},
        simplify_algebra{},
        dead_code_elimination{},
        auto_contiguous{},
......
@@ -26,7 +26,6 @@ struct tf_parser
{
    using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
    using node_map      = std::map<std::string, tensorflow::NodeDef>;
-   // using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
    using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;

    node_map nodes;
...@@ -149,9 +148,26 @@ struct tf_parser ...@@ -149,9 +148,26 @@ struct tf_parser
return axes; return axes;
} }
std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
{
uint32_t bitwise_compare = 1;
std::vector<int64_t> axes;
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to begin
if(((mask >> i) & bitwise_compare) == 1)
axes.push_back(1);
else
axes.push_back(0);
}
return axes;
}
    tf_parser()
    {
add_generic_op("All", op::identity{});
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", op::identity{});
add_generic_op("LessEqual", op::identity{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0}); add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Rsqrt", op::rsqrt{}); add_generic_op("Rsqrt", op::rsqrt{});
@@ -166,6 +182,7 @@ struct tf_parser
        add_mem_op("AvgPool", &tf_parser::parse_pooling);
        add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd); add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("Cast", &tf_parser::parse_cast, false); add_mem_op("Cast", &tf_parser::parse_cast, false);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false); add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
...@@ -177,14 +194,15 @@ struct tf_parser ...@@ -177,14 +194,15 @@ struct tf_parser
add_mem_op("GatherV2", &tf_parser::parse_gather, false); add_mem_op("GatherV2", &tf_parser::parse_gather, false);
add_mem_op("MatMul", &tf_parser::parse_matmul, false); add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean); add_mem_op("Mean", &tf_parser::parse_mean, false);
add_mem_op("OneHot", &tf_parser::parse_onehot, false);
add_mem_op("Pack", &tf_parser::parse_pack, false); add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad); add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape, false); add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false); add_mem_op("Slice", &tf_parser::parse_slice, false);
add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>); add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
add_mem_op("Transpose", &tf_parser::parse_transpose, false); add_mem_op("Transpose", &tf_parser::parse_transpose, false);
} }
...@@ -547,7 +565,7 @@ struct tf_parser ...@@ -547,7 +565,7 @@ struct tf_parser
} }
if(contains(attributes, "transpose_b")) if(contains(attributes, "transpose_b"))
{ {
transb = attributes.at("transpose_a").b(); transb = attributes.at("transpose_b").b();
} }
if(contains(attributes, "adj_x")) if(contains(attributes, "adj_x"))
@@ -574,8 +592,7 @@ struct tf_parser
    parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        bool keep_dims = attributes.at("keep_dims").b();
-       auto lens      = args[0]->get_shape().lens();
-       auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector<int64_t>(), lens.size());
+       auto axes      = args[1]->eval().get<int32_t>().to_vector<int64_t>();

        if(keep_dims)
        {
@@ -588,6 +605,32 @@ struct tf_parser
        }
    }
instruction_ref
parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
int64_t axis = -1;
float on_value = args[2]->eval().at<float>();
float off_value = args[3]->eval().at<float>();
std::vector<float> depth_input(depth * depth, off_value);
for(int i = 0; i < depth; i++)
{
depth_input[depth * i + i] = on_value;
}
if(contains(attributes, "axis"))
axis = attributes.at("axis").i();
if(axis == -1)
{
shape s{shape::float_type, {depth, depth}};
auto l0 = prog.add_literal({s, depth_input});
return prog.add_instruction(op::gather{0}, {l0, args[0]});
}
MIGRAPHX_THROW("MIGraphX does not support axis != -1");
}
    instruction_ref parse_pack(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args)
@@ -799,21 +842,50 @@ struct tf_parser
                                     std::vector<instruction_ref> args)
    {
        op::slice op;
        auto starts = args[1]->eval().get<int32_t>().to_vector();
        auto ends   = args[2]->eval().get<int32_t>().to_vector();
-       size_t num_axes = args[0]->get_shape().lens().size();
+       auto l0                  = args[0];
+       size_t num_axes          = l0->get_shape().lens().size();
+       std::vector<size_t> axes = l0->get_shape().lens();

        op.starts = std::vector<int64_t>(starts.begin(), starts.end());
        op.ends   = std::vector<int64_t>(ends.begin(), ends.end());
        op.axes   = std::vector<int64_t>(num_axes);
        std::iota(op.axes.begin(), op.axes.end(), 0);
uint32_t begin_mask = 0;
uint32_t end_mask = 0;
        uint32_t shrink_axis_mask = 0;
        uint32_t bitwise_compare  = 1;
        std::vector<int64_t> squeeze_axes;
if(contains(attributes, "begin_mask"))
begin_mask = static_cast<uint32_t>(attributes.at("begin_mask").i());
if(contains(attributes, "end_mask"))
end_mask = static_cast<uint32_t>(attributes.at("end_mask").i());
if(contains(attributes, "shrink_axis_mask")) if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i()); shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask);
for(size_t i = 0; i < num_axes; i++)
{
if(begin_axes.at(i) == 1)
{
op.starts.at(i) = 0;
}
if(end_axes.at(i) == 1)
{
op.ends.at(i) = axes.at(i);
}
}
auto l1 = prog.add_instruction(op, l0);
if(shrink_axis_mask == 0)
return l1;
        for(size_t i = 0; i < num_axes; i++)
        {
            // the LSB corresponds to axis 0 when determining which axes to squeeze
@@ -821,8 +893,7 @@ struct tf_parser
                squeeze_axes.push_back(i);
        }

-       auto l0 = prog.add_instruction(op, make_contiguous(args[0]));
-       return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
+       return prog.add_instruction(op::squeeze{squeeze_axes}, l1);
    }

    instruction_ref
@@ -862,10 +933,16 @@ struct tf_parser
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
// assert ops ignored
if(node.op() == "Assert" or contains(name, "Assert"))
return;
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
// control dependencies (signified by ^ before the name) are ignored
if(contains(input, "^"))
continue;
                if(nodes.count(input) > 0)
                {
                    auto&& iname = get_name(nodes.at(input));
......
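Note: two of the additions above deserve a word. parse_onehot builds a depth x depth table filled with off_value and on_value on the diagonal, then gathers rows of it by index, which is why only axis == -1 is supported. get_axes_from_mask decodes the StridedSlice begin/end masks, where a set bit i means "ignore the supplied start/end for axis i and use the full extent" (the LSB is axis 0). A standalone sketch of the decoding (illustrative, not the parser itself):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<std::int64_t> axes_from_mask(std::size_t num_axes, std::uint32_t mask)
{
    std::vector<std::int64_t> axes;
    for(std::size_t i = 0; i < num_axes; i++)
        axes.push_back(((mask >> i) & 1u) == 1 ? 1 : 0); // bit i selects axis i
    return axes;
}

int main()
{
    for(auto v : axes_from_mask(4, 5)) // mask 0b0101: axes 0 and 2 take the full range
        std::printf("%lld ", static_cast<long long>(v)); // 1 0 1 0
}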
@@ -502,6 +502,24 @@ struct test_triadd2 : verify_program<test_triadd2>
    }
};
struct test_mul_add : verify_program<test_mul_add>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape bs{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto a = p.add_parameter("a", bs);
auto b = p.add_parameter("b", bs);
auto ab = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, a);
auto bb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, b);
auto mul = p.add_instruction(migraphx::op::mul{}, x, ab);
p.add_instruction(migraphx::op::add{}, mul, bb);
return p;
}
};
struct test_add_broadcast : verify_program<test_add_broadcast>
{
    migraphx::program create_program() const
......
@@ -5,6 +5,8 @@
namespace match = migraphx::match;
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }
template <class M>
migraphx::match::matcher_result find_match(migraphx::program& p, M&& m)
{
@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
    EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_either_args_any1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_all_of1)
{
    migraphx::program p;
@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
    EXPECT(bool{r.result == sum});
}
TEST_CASE(match_lazy_any_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::any_of(match::any(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == one});
}
TEST_CASE(match_lazy_all_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::all_of(match::none(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_lazy_none_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::none_of(match::any(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_any_of1)
{
    migraphx::program p;
@@ -396,6 +503,97 @@ TEST_CASE(match_any_of2)
    EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_any_of_lazy1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("sum"), match::name("sum")).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
match::args(match::any(), match::any()).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("@literal"), match::name("@literal")).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of(
match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
match::args(match::any().bind("x2"), match::any().bind("y2"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
EXPECT(bool{r.instructions["x1"] == one});
EXPECT(bool{r.instructions["y1"] == two});
EXPECT(not migraphx::contains(r.instructions, "x2"));
EXPECT(not migraphx::contains(r.instructions, "y2"));
}
TEST_CASE(match_any_of_lazy5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of(
match::args(match::any().bind("x1"), match::any().bind("y1")),
match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
EXPECT(bool{r.instructions["x1"] == one});
EXPECT(bool{r.instructions["y1"] == two});
EXPECT(not migraphx::contains(r.instructions, "x2"));
EXPECT(not migraphx::contains(r.instructions, "y2"));
}
TEST_CASE(match_none_of1)
{
    migraphx::program p;
......
@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/onnx.hpp>

#include "test.hpp"
@@ -514,6 +515,32 @@ TEST_CASE(shape_gather_test)
    EXPECT(p == prog);
}
TEST_CASE(transpose_gather_test)
{
migraphx::program p;
auto make_contiguous = [&p](migraphx::instruction_ref ins) {
if(ins->get_shape().standard())
{
return ins;
}
return p.add_instruction(migraphx::op::contiguous{}, ins);
};
auto data = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}});
auto ind =
p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
auto tr_data = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, data);
auto tr_ind = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, ind);
int axis = 1;
p.add_instruction(
migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind));
auto prog = migraphx::parse_onnx("transpose_gather.onnx");
EXPECT(p == prog);
}
TEST_CASE(flatten_test)
{
    migraphx::program p;
......