Commit 6b36c82e authored by Khalique

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into test_gen_scripts

parents 34150e61 01ff5321
...
@@ -12,7 +12,7 @@ add_library(migraphx
     eliminate_concat.cpp
     eliminate_identity.cpp
     eliminate_pad.cpp
-    fwd_conv_batchnorm_rewrite.cpp
+    rewrite_batchnorm.cpp
     rewrite_rnn.cpp
     rewrite_pooling.cpp
     env.cpp
...
...
@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
         [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
             auto result = m.match(ctx, ins);
             if(result != ctx.not_found())
-                ctx.instructions.emplace(name, ins);
+                ctx.instructions[name] = ins;
             return result;
         });
 }
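
Note: the switch from emplace to operator[] changes rebinding semantics. emplace is a no-op when the key is already present, so a matcher that rebinds a name (for example inside any_of or either_arg) would keep the stale instruction; operator[] overwrites it. A standalone illustration of the difference with std::unordered_map (plain C++, not MIGraphX code):

    #include <cassert>
    #include <string>
    #include <unordered_map>

    int main()
    {
        std::unordered_map<std::string, int> instructions;
        instructions.emplace("x", 1);
        instructions.emplace("x", 2); // no-op: "x" is already bound
        assert(instructions.at("x") == 1);
        instructions["x"] = 2; // overwrites the previous binding
        assert(instructions.at("x") == 2);
    }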
...
@@ -240,6 +240,21 @@ void find_matches(program& p, Ms&&... ms)
     }
 }

+template <class M>
+struct find_skip
+{
+    M m;
+    M matcher() const { return m; }
+    void apply(program&, const matcher_result&) const {}
+};
+
+template <class M>
+find_skip<M> make_find_skip(M m)
+{
+    return {m};
+}
+
 struct lazy_and
 {
     template <class F, class G>
...
@@ -311,6 +326,12 @@ const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
 const constexpr auto any_of  = match_fold_f<lazy_or, false, true>{};
 const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};

+template <class... Ms>
+auto skip_matches(Ms... ms)
+{
+    return make_find_skip(any_of(ms...));
+}
+
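
Note: skip_matches ties the two additions together: find_skip is a matcher/apply pair whose apply does nothing, so passing skip_matches(...) to find_matches claims the matched instructions without transforming them, and later matchers in the same call do not fire on them (find_matches tries matchers in order and stops at the first match). A hypothetical use (sketch only; the matcher choice is illustrative):

    // Claim "contiguous" instructions first so the rewrite matcher that
    // follows in the same find_matches call never fires on them.
    match::find_matches(p,
                        match::skip_matches(match::name("contiguous")),
                        find_add_lit_broadcast{});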
 inline auto inputs()
 {
     return [](auto ins, auto f) {
...
@@ -369,6 +390,50 @@ MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
     return ctx.not_found();
 }

+inline auto used_once_recursive(std::size_t depth)
+{
+    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
+        // Used once
+        if(start->outputs().size() == 1)
+            return start;
+        // Unused
+        if(start->outputs().empty())
+        {
+            if(std::next(start) == ctx.not_found())
+                return start;
+            else
+                return ctx.not_found();
+        }
+        // Check for dead instructions
+        auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
+            if(n == 0)
+                return false;
+            if(ins->get_shape().elements() == 0)
+                return false;
+            if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
+                return true;
+            return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
+                return self(i, n - 1);
+            });
+        });
+        auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
+            return is_dead(i, depth);
+        });
+        if(dead + 1 == start->outputs().size())
+            return start;
+        return ctx.not_found();
+    });
+}
+
+MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
+
+MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
+{
+    if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
+        return ins;
+    return ctx.not_found();
+}
+
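
Note: both new matchers lean on the convention that ctx.not_found() is the program's end() iterator, so std::prev(ctx.not_found()) is the final instruction, i.e. the program's result: it has no users yet must never be classified as dead or unused. The same iterator idiom on a std::list (illustration only):

    #include <cassert>
    #include <iterator>
    #include <list>

    int main()
    {
        std::list<int> program{1, 2, 3};
        auto not_found = program.end();        // sentinel, like matcher_context::not_found()
        auto last      = std::prev(not_found); // the final instruction (program result)
        assert(*last == 3);                    // has no users, yet is not "unused"
    }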
 template <class... Ms>
 auto skip_output(Ms... ms)
 {
...
@@ -404,6 +469,12 @@ inline auto name(std::unordered_set<std::string> names)
     });
 }

+template <class... Ts>
+inline auto name(std::string s, Ts... xs) // NOLINT
+{
+    return name(std::unordered_set<std::string>{std::move(s), std::move(xs)...});
+}
+
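
Note: the variadic overload just forwards to the set-based one, so several operator names can be matched in a single call, as the new find_inner_broadcast matcher later in this commit does:

    // Matches an instruction whose operator is either "mul" or "add".
    match::name("mul", "add")(
        match::args(match::name("broadcast"), match::name("broadcast")));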
 inline auto nargs(std::size_t n)
 {
     return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
...
...
@@ -13,9 +13,9 @@ struct program;

 /**
  * Rewrite batchnorm to a multiply and add.
  */
-struct fwd_conv_batchnorm_rewrite
+struct rewrite_batchnorm
 {
-    std::string name() const { return "fwd_conv_batchnorm_rewrite"; }
+    std::string name() const { return "rewrite_batchnorm"; }
     void apply(program& p) const;
 };
...
...
@@ -375,7 +375,7 @@ struct onnx_parser
         if(args.size() == 3)
         {
             uint64_t axis = 1;
-            auto l1 = prog.add_instruction(op, args[0], args[1]);
+            auto l1 = prog.add_instruction(op, l0, args[1]);
             auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]);
             return prog.add_instruction(op::add{}, l1, l2);
         }
...
-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/op/batch_norm.hpp>
 #include <migraphx/op/broadcast.hpp>
 #include <migraphx/op/add.hpp>
+#include <migraphx/op/mul.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/dfor.hpp>
...
@@ -11,7 +12,7 @@

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-void fwd_conv_batchnorm_rewrite::apply(program& p) const
+void rewrite_batchnorm::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
...
@@ -25,46 +26,30 @@ void rewrite_batchnorm::apply(program& p) const
         if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
             continue;

-        auto conv_ins = ins->inputs()[0];
-        if(conv_ins->name() != "convolution")
-            continue;
-        // Get convolution weights
-        auto weights = conv_ins->inputs()[1]->eval();
-        if(weights.empty())
-            continue;
+        auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}};
+
         // Get epsilon
         auto bn_op   = any_cast<op::batch_norm_inference>(ins->get_operator());
         auto epsilon = bn_op.epsilon;
-        // Get convolution op
-        auto conv_op      = conv_ins->get_operator();
-        auto weights_lens = weights.get_shape().lens();
-        auto conv_lens    = conv_ins->get_shape().lens();
-        argument new_weights{weights.get_shape()};
-        argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
-        visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
-            [&](auto weights2,
-                auto gamma2,
-                auto bias2,
-                auto mean2,
-                auto variance2,
-                auto new_weights2,
-                auto new_bias2) {
-                dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
-                    [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
-                        new_weights2(k, c, h, w) =
-                            gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w);
-                    });
-                dfor(new_bias.get_shape().elements())([&](std::size_t c) {
-                    new_bias2[c] =
-                        bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
-                });
-            });
-        // Replace convolution instruction with updated weights
-        auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
-        auto l_bias    = p.add_literal({new_bias.get_shape(), new_bias.data()});
-        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
-        auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape().lens()}, l_bias);
-        p.replace_instruction(ins, op::add{}, {c, b});
+        argument a{s};
+        argument b{s};
+        visit_all(gamma, bias, mean, variance, a, b)(
+            [&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
+                dfor(a.get_shape().elements())(
+                    [&](std::size_t c) { a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon); });
+                dfor(b.get_shape().elements())([&](std::size_t c) {
+                    b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
+                });
+            });
+
+        auto broadcast   = op::broadcast{1, ins->get_shape().lens()};
+        auto a_ins       = p.add_literal({a.get_shape(), a.data()});
+        auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
+        auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
+        auto b_ins       = p.add_literal({b.get_shape(), b.data()});
+        auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
+        auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
+        p.replace_instruction(ins, add);
     }
 }
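
Note: the rewrite no longer requires a preceding convolution; it folds inference-mode batchnorm into one per-channel multiply and add, y = gamma*(x - mean)/sqrt(variance + epsilon) + bias = a*x + b, with a = gamma/sqrt(variance + epsilon) and b = bias - a*mean, exactly the values written into the a and b arguments above. A standalone numeric check of that algebra:

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float x = 2.0f, gamma = 0.5f, bias = 0.25f;
        const float mean = 1.0f, variance = 4.0f, epsilon = 1e-5f;
        const float bn = gamma * (x - mean) / std::sqrt(variance + epsilon) + bias;
        const float a  = gamma / std::sqrt(variance + epsilon);
        const float b  = bias - a * mean;
        assert(std::fabs(bn - (a * x + b)) < 1e-6f);
    }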
...
 #include <migraphx/simplify_algebra.hpp>
+#include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/op/add.hpp>
+#include <migraphx/op/mul.hpp>
+#include <migraphx/op/broadcast.hpp>
 #include <migraphx/matcher.hpp>
 #include <migraphx/literal.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct find_add_lit_broadcast
-{
-    auto lit_broadcast() const
-    {
-        return match::any_of(match::name("@literal"), match::name("broadcast"));
-    }
-    auto not_lit_broadcast() const
-    {
-        return match::none_of(match::name("@literal"), match::name("broadcast"));
-    }
-    auto add_lit_broadcast(std::string x, std::string y) const
+auto lit_broadcast() { return match::any_of(match::is_constant(), match::name("broadcast")); }
+auto not_lit_broadcast() { return match::none_of(match::is_constant(), match::name("broadcast")); }
+
+auto op_lit_broadcast(std::string op, std::string x, std::string y)
+{
+    return match::name(std::move(op))(match::either_arg(0, 1)(
+        lit_broadcast().bind(std::move(x)), not_lit_broadcast().bind(std::move(y))));
+}
+
+auto conv_const_weights()
+{
+    return match::name("convolution")(match::used_once(),
+                                      match::args(match::any(), match::is_constant().bind("w")));
+}
+
+struct find_mul_conv
+{
+    auto matcher() const
+    {
+        return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"),
+                                                          match::name("broadcast").bind("a")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins      = r.result;
+        auto conv_ins = r.instructions["conv"];
+        auto a_ins    = r.instructions["a"];
+        auto w_ins    = r.instructions["w"];
+
+        auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
+        if(broadcast_op.axis != 1)
+            return;
+
+        auto new_a = p.insert_instruction(
+            ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
+        auto new_mul  = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
+        auto new_conv = p.insert_instruction(
+            ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
+        p.replace_instruction(ins, new_conv);
+    }
+};
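
Note: find_mul_conv relies on convolution being linear in its weights: scaling output channel k of the constant weights by a[k] scales that channel's output by a[k], so broadcast(a) * conv(x, w) becomes conv(x, broadcast(a) * w), with the broadcast moved from axis 1 of the output to axis 0 of the weights. A minimal linearity check for one output channel (plain C++, illustration only):

    #include <cassert>
    #include <cmath>
    #include <numeric>
    #include <vector>

    int main()
    {
        const std::vector<float> w{1.0f, 2.0f, 3.0f}; // one channel's weights
        const std::vector<float> x{4.0f, 5.0f, 6.0f}; // matching input window
        const float a   = 0.5f;
        const float out = std::inner_product(w.begin(), w.end(), x.begin(), 0.0f);
        std::vector<float> aw;
        for(float v : w)
            aw.push_back(a * v);
        const float out2 = std::inner_product(aw.begin(), aw.end(), x.begin(), 0.0f);
        assert(std::fabs(a * out - out2) < 1e-6f); // a * conv(x, w) == conv(x, a * w)
    }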
-    {
-        return match::name("add")(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
-                                                          not_lit_broadcast().bind(std::move(y))));
-    }
+// a * (x + b) => a * x + a * b
+struct find_mul_add
+{
+    auto matcher() const
+    {
+        return match::name("mul")(match::either_arg(0, 1)(
+            match::name("add")(
+                match::either_arg(0, 1)(
+                    match::any().bind("x"),
+                    match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
+                match::none_of(match::args(match::is_constant(), match::is_constant())),
+                match::used_once()),
+            match::is_constant().bind("a")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+        auto x_ins = r.instructions["x"];
+        assert(x_ins != b_ins);
+        auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins);
+        auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins);
+        p.replace_instruction(ins, op::add{}, ax_ins, ab_ins);
+    }
+};
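
Note: distributing the constant, a*(x + b) => a*x + a*b, looks like extra work, but a and b are both constant (or a constant-weight convolution), so a*b can later be folded by the propagate_constant pass listed in target.cpp below, and the remaining a*x can feed find_mul_conv. A quick check of the identity itself:

    #include <cassert>

    int main()
    {
        const int a = 3, x = 7, b = 5;
        assert(a * (x + b) == a * x + a * b);
    }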
-    auto matcher() const
-    {
-        return match::name("add")(
-            match::args(add_lit_broadcast("a", "x"), add_lit_broadcast("b", "y")));
-    }
+struct find_add_lit_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("add")(
+            match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto x_ins = r.instructions["x"];
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+        auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
+        p.replace_instruction(ins, op::add{}, x_ins, sumab);
+    }
+};
+
+struct find_double_add_lit_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("add")(
+            match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
+    }

     void apply(program& p, match::matcher_result r) const
...
@@ -36,11 +117,9 @@ struct find_add_lit_broadcast
         auto a_ins = r.instructions["a"];
         auto b_ins = r.instructions["b"];

-        if(a_ins->name() != b_ins->name())
-            return;
-
         instruction_ref sumab;
-        if(a_ins->name() == "broadcast")
+        if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
         {
             if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
                 return;
...
@@ -59,7 +138,46 @@
     }
 };
-void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
+struct find_inner_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("mul", "add")(
+            match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto x_ins = r.instructions["x"];
+        auto y_ins = r.instructions["y"];
+
+        auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
+        auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
+        if(xbroadcast.axis != ybroadcast.axis)
+            return;
+
+        auto op = p.insert_instruction(
+            ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
+        p.replace_instruction(ins, xbroadcast, op);
+    }
+};
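
Note: find_inner_broadcast hoists an elementwise op above two matching broadcasts, rewriting broadcast(x) op broadcast(y) into broadcast(x op y). This is valid exactly when both broadcasts use the same axis (the condition checked above), because the operands then share one index map: with b(v)[i] = v[idx(i)], we get (b(x) op b(y))[i] = x[idx(i)] op y[idx(i)] = b(x op y)[i]. The op then runs on the small pre-broadcast tensors, and the simplify_inner_broadcast test near the end of this commit exercises the rewrite.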
+void simplify_algebra::apply(program& p) const
+{
+    // Run simplifications multiple times
+    for(int i = 0; i < 4; i++)
+    {
+        match::find_matches(p,
+                            find_inner_broadcast{},
+                            find_double_add_lit_broadcast{},
+                            find_add_lit_broadcast{},
+                            find_mul_conv{},
+                            find_mul_add{});
+        dead_code_elimination{}.apply(p);
+    }
+}
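
Note: the rules enable one another (find_mul_add produces the mul that find_mul_conv consumes, and so on), so the pass runs the whole set a fixed four times with dead-code elimination in between, rather than tracking convergence. A common alternative is iterating until the program stops changing; a sketch, assuming program is copyable and comparable with == as the tests below use it (hypothetical driver, not the code in this commit):

    migraphx::program prev;
    do
    {
        prev = p;
        match::find_matches(p, find_inner_broadcast{}, find_mul_add{});
        dead_code_elimination{}.apply(p);
    } while(!(p == prev));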
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...
@@ -16,6 +16,7 @@ add_library(migraphx_device
     device/argmin.cpp
     device/max.cpp
     device/min.cpp
+    device/mul_add.cpp
     device/exp.cpp
     device/erf.cpp
     device/log.cpp
...
...
@@ -6,6 +6,16 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

+void mul_add_relu(hipStream_t stream,
+                  const argument& result,
+                  const argument& arg1,
+                  const argument& arg2,
+                  const argument& arg3)
+{
+    nary(stream, result, arg1, arg2, arg3)(
+        [](auto x, auto a, auto b) { return std::max<decltype(a * x + b)>(0, a * x + b); });
+}
+
 void add_relu(hipStream_t stream,
               const argument& result,
               const argument& arg1,
...
...
@@ -118,6 +118,111 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
         });
 }

+template <class F, class... Arguments>
+void nary_double_broadcast_vec_impl(
+    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
+{
+    assert(barg1.get_shape().broadcasted());
+    assert(barg2.get_shape().broadcasted());
+    assert(barg1.get_shape() == barg2.get_shape());
+    const auto& output_shape = result.get_shape();
+    const auto& b_shape      = barg1.get_shape();
+    auto bdim =
+        std::distance(b_shape.strides().begin(),
+                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
+                          return x != 0;
+                      }));
+    auto bdim_len         = output_shape.lens()[bdim];
+    auto bdim_stride      = output_shape.strides()[bdim];
+    auto bdim_next_stride = bdim_stride * bdim_len;
+
+    const std::size_t vec_size     = 4;
+    const std::size_t nlocal       = 1024;
+    const std::size_t nglobal      = 256 * nlocal;
+    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
+        [&](auto output, auto binput1, auto binput2, auto... inputs) {
+            using type                  = typename decltype(output)::value_type;
+            const std::size_t nelements = output.size() / vec_size;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i] = binput1.data()[i];
+                }
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i + bdim_vec_len] = binput2.data()[i];
+                }
+                __syncthreads();
+                auto* bp = as_pointer(buffer);
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b1   = bp[bidx];
+                    auto b2   = bp[bidx + bdim_len];
+                    auto out  = output.data()[i];
+                    for(std::size_t j = 0; j < vec_size; j++)
+                    {
+                        out[j] = f(inputs.data()[i][j]..., b2, b1);
+                    }
+                    output.data()[i] = out;
+                }
+            });
+        });
+}
+
+template <class F, class... Arguments>
+void nary_double_broadcast_impl(
+    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
+{
+    assert(barg1.get_shape().broadcasted());
+    assert(barg2.get_shape().broadcasted());
+    assert(barg1.get_shape() == barg2.get_shape());
+    const auto& output_shape = result.get_shape();
+    const auto& b_shape      = barg1.get_shape();
+    auto bdim =
+        std::distance(b_shape.strides().begin(),
+                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
+                          return x != 0;
+                      }));
+    auto bdim_len         = output_shape.lens()[bdim];
+    auto bdim_stride      = output_shape.strides()[bdim];
+    auto bdim_next_stride = bdim_stride * bdim_len;
+
+    const std::size_t nlocal  = 1024;
+    const std::size_t nglobal = 256 * nlocal;
+    std::size_t nelements     = result.get_shape().elements();
+    hip_visit_all(result, barg1, barg2, args...)(
+        [&](auto output, auto binput1, auto binput2, auto... inputs) {
+            using type = typename decltype(output)::value_type;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_len; i += nlocal)
+                {
+                    buffer[i] = binput1.data()[i];
+                }
+                for(size_t i = idx.local; i < bdim_len; i += nlocal)
+                {
+                    buffer[i + bdim_len] = binput2.data()[i];
+                }
+                __syncthreads();
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = (i % bdim_next_stride) / bdim_stride;
+                    auto b1   = buffer[bidx];
+                    auto b2   = buffer[bidx + bdim_len];
+                    output.data()[i] = f(inputs.data()[i]..., b2, b1);
+                }
+            });
+        });
+}
+
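
Note: both kernels stage the two broadcast vectors in one LDS buffer (barg1 at offset 0, barg2 at offset bdim_len) and recover the broadcast index of a flat element offset with bidx = (i % bdim_next_stride) / bdim_stride; the call f(inputs..., b2, b1) then restores the original argument order, since barg1 was the last argument and barg2 the second-to-last. A host-side check of the index arithmetic for an NCHW output broadcast over C:

    #include <cassert>
    #include <cstddef>

    int main()
    {
        // Output NCHW = {2, 3, 4, 5}, broadcast over C (axis 1):
        // stride of C is H*W = 20, next stride up (N) is C*H*W = 60.
        const std::size_t bdim_len = 3, bdim_stride = 20;
        const std::size_t bdim_next_stride = bdim_stride * bdim_len; // 60
        for(std::size_t n = 0; n < 2; n++)
            for(std::size_t c = 0; c < 3; c++)
                for(std::size_t hw = 0; hw < 20; hw++)
                {
                    const std::size_t i = n * 60 + c * 20 + hw;
                    assert((i % bdim_next_stride) / bdim_stride == c);
                }
    }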
 template <class F, class... Arguments>
 void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
...
@@ -177,49 +282,113 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
 }

 template <class... Arguments>
-auto nary(hipStream_t stream, argument result)
+bool broadcastable(bool& divisible_by_4,
+                   std::size_t max_size,
+                   const argument& result,
+                   const argument& barg,
+                   const Arguments&... args)
+{
+    divisible_by_4 = false;
+    auto bshape    = barg.get_shape();
+    const bool standard =
+        all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
+    const bool same_shapes =
+        all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
+    // TODO: Check result and args shape is the same
+    if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
+    {
+        auto not_zero       = [](auto x) { return x != 0; };
+        const auto& strides = bshape.strides();
+        auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
+        auto b_idx          = std::distance(strides.begin(), b_it);
+        auto b_len          = result.get_shape().lens()[b_idx];
+        auto b_stride       = result.get_shape().strides()[b_idx];
+        assert(bshape.lens()[b_idx] == b_len);
+        if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero))
+        {
+            divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
+                             (front_args(args...).get_shape().elements() % 4 == 0);
+            return true;
+        }
+    }
+    return false;
+}
+
+inline bool broadcastable(bool& divisible_by_4, std::size_t, const argument&, const argument&)
+{
+    divisible_by_4 = false;
+    return false;
+}
+
+// Nullary
+inline auto nary(hipStream_t stream, argument result)
 {
     return [=](auto f) { nary_standard_impl(stream, f, result); };
 }

+// Unary
+inline auto nary(hipStream_t stream, argument result, argument arg)
+{
+    return [=](auto f) { nary_impl(stream, f, result, arg); };
+}
+
+// Binary
+inline auto nary(hipStream_t stream, argument result, argument arg, argument barg)
+{
+    return [=](auto f) {
+        bool divisible_by_4 = false;
+        if(broadcastable(divisible_by_4, 2048, result, barg, arg))
+        {
+            if(divisible_by_4)
+                nary_broadcast_vec_impl(stream, f, result, barg, arg);
+            else
+                nary_broadcast_impl(stream, f, result, barg, arg);
+        }
+        else
+        {
+            nary_impl(stream, f, result, arg, barg);
+        }
+    };
+}
+
 template <class... Arguments>
 auto nary(hipStream_t stream, argument result, Arguments... args)
 {
+    static_assert(sizeof...(args) > 2, "Args needs to be greater than 2");
     return [=](auto f) {
-        auto barg     = back_args(args...);
-        bool fallback = pop_back_args(args...)([&](auto&&... args2) {
-            auto bshape = barg.get_shape();
-            const bool standard =
-                all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
-            const bool same_shapes = all_of(
-                {args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
-            // TODO: Check result and args shape is the same
-            if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
-            {
-                auto not_zero       = [](auto x) { return x != 0; };
-                const auto& strides = bshape.strides();
-                auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
-                auto b_idx          = std::distance(strides.begin(), b_it);
-                auto b_len          = result.get_shape().lens()[b_idx];
-                auto b_stride       = result.get_shape().strides()[b_idx];
-                assert(bshape.lens()[b_idx] == b_len);
-                if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
-                {
-                    const bool divisible_by_4 =
-                        (b_len % 4 == 0) and (b_stride % 4 == 0) and
-                        (front_args(args...).get_shape().elements() % 4 == 0);
-                    if(divisible_by_4)
-                        nary_broadcast_vec_impl(stream, f, result, barg, args2...);
-                    else
-                        nary_broadcast_impl(stream, f, result, barg, args2...);
-                    return false;
-                }
-            }
+        auto barg1     = back_args(args...);
+        bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
+            auto barg2     = back_args(args2...);
+            bool fallback2 =
+                barg2.get_shape() != barg1.get_shape() or not barg2.get_shape().broadcasted() or
+                pop_back_args(args2...)([&](auto&&... args3) {
+                    bool divisible_by_4 = false;
+                    if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
+                    {
+                        if(divisible_by_4)
+                            nary_double_broadcast_vec_impl(
+                                stream, f, result, barg1, barg2, args3...);
+                        else
+                            nary_double_broadcast_impl(stream, f, result, barg1, barg2, args3...);
+                        return false;
+                    }
+                    return true;
+                });
+            if(not fallback2)
+                return false;
+            bool divisible_by_4 = false;
+            if(broadcastable(divisible_by_4, 2048, result, barg1, args2...))
+            {
+                if(divisible_by_4)
+                    nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
+                else
+                    nary_broadcast_impl(stream, f, result, barg1, args2...);
+                return false;
+            }
             return true;
         });
-        if(fallback)
+        if(fallback1)
             nary_impl(stream, f, result, args...);
     };
 }
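
Note: the variadic overload now tries kernels in order: the double-broadcast path when the two trailing arguments share one broadcast shape (limit 1024 elements, since two vectors share the 2048-element LDS buffer), then the single-broadcast path (limit 2048), then the general nary_impl. A simplified sketch of just the dispatch order, with the divisibility and shape checks elided (assumed limits taken from the code above):

    #include <cstddef>

    enum class kernel { double_broadcast, single_broadcast, general };

    kernel pick(bool two_broadcasts_same_shape, bool one_broadcast, std::size_t b_len)
    {
        if(two_broadcasts_same_shape and b_len <= 1024)
            return kernel::double_broadcast;
        if(one_broadcast and b_len <= 2048)
            return kernel::single_broadcast;
        return kernel::general;
    }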
...
+#include <migraphx/gpu/device/add_relu.hpp>
+#include <migraphx/gpu/device/nary.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace gpu {
+namespace device {
+
+void mul_add(hipStream_t stream,
+             const argument& result,
+             const argument& arg1,
+             const argument& arg2,
+             const argument& arg3)
+{
+    nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b) { return a * x + b; });
+}
+
+} // namespace device
+} // namespace gpu
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
...
@@ -2,6 +2,7 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/gpu/miopen.hpp>
 #include <migraphx/gpu/convolution.hpp>
+#include <migraphx/gpu/device/mul_add.hpp>
 #include <migraphx/gpu/device/add_relu.hpp>
 #include <migraphx/gpu/device/add.hpp>
 #include <migraphx/instruction.hpp>
...
@@ -198,21 +199,62 @@ struct hip_add_relu
     }
 };

+struct hip_mul_add
+{
+    std::string name() const { return "hip::mul_add"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4);
+        return inputs.front();
+    }
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
+    {
+        device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
+        return args.at(3);
+    }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
+};
+
+struct hip_mul_add_relu
+{
+    std::string name() const { return "hip::mul_add_relu"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4);
+        return inputs.front();
+    }
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
+    {
+        device::mul_add_relu(
+            ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
+        return args.at(3);
+    }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
+};
+
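
Note: both ops use destination passing: the fourth input is the preallocated output buffer, compute writes through it and returns it, and output_alias reports that the result aliases the last input so the memory planner does not allocate a second buffer. The same convention in miniature (illustration only, not the MIGraphX API):

    #include <cstddef>
    #include <vector>

    // The callee writes into the caller-provided buffer and returns it.
    std::vector<float>& mul_add(const std::vector<float>& x,
                                const std::vector<float>& a,
                                const std::vector<float>& b,
                                std::vector<float>& out)
    {
        for(std::size_t i = 0; i < out.size(); i++)
            out[i] = a[i] * x[i] + b[i];
        return out; // aliases the last argument, as output_alias reports
    }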
 void move_broadcasted_back(std::vector<instruction_ref>& args)
 {
     // Ensure the last arguments is the broadcasted one
-    auto it = std::find_if(
-        args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
-    if(it != args.end())
-        std::swap(*it, *std::prev(args.end(), 2));
+    auto last = std::prev(args.end());
+    auto it =
+        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
+    if(it != last)
+        std::swap(*it, *std::prev(last));
 }

 void move_standard_front(std::vector<instruction_ref>& args)
 {
     // Ensure the first arguments is the standard one
-    auto it = std::find_if(
-        args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
-    if(it != args.end())
+    auto last = std::prev(args.end());
+    auto it =
+        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
+    if(it != last)
         std::swap(*it, args.front());
 }
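
Note: the revised search range stops at prev(end()) because the final argument is the output allocation and must stay in place; previously the scan could select (or swap with) that slot. A small runnable model of the new logic:

    #include <algorithm>
    #include <cassert>
    #include <iterator>
    #include <string>
    #include <vector>

    int main()
    {
        std::vector<std::string> args{"broadcast", "standard", "alloc"};
        auto last = std::prev(args.end()); // points at the output allocation
        auto it =
            std::find_if(args.begin(), last, [](const auto& s) { return s == "broadcast"; });
        if(it != last)
            std::swap(*it, *std::prev(last));
        assert(args.back() == "alloc");  // allocation untouched
        assert(args[1] == "broadcast");  // broadcasted arg moved just before it
    }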
...
@@ -220,11 +262,13 @@ struct find_add_relu
 {
     auto matcher() const
     {
-        return match::name("gpu::relu")(
-            match::arg(0)(match::any_of(match::name("gpu::add"),
-                                        match::name("hip::triadd"),
-                                        match::any_of[match::inputs()](match::standard_shape()))
-                              .bind("add")));
+        return match::name("gpu::relu")(match::arg(0)(
+            match::used_once(),
+            match::any_of(match::name("gpu::add"),
+                          match::name("hip::triadd"),
+                          match::any_of(match::name("@literal"),
+                                        match::any_of[match::inputs()](match::standard_shape())))
+                .bind("add")));
     }

     void apply(program& p, match::matcher_result r) const
...
@@ -249,8 +293,10 @@ struct find_triadd
     auto matcher() const
     {
         return match::name("gpu::add")(match::either_arg(0, 1)(
-            match::name("gpu::add").bind("add"),
-            match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
+            match::name("gpu::add")(match::used_once()).bind("add"),
+            match::any(match::any_of(match::name("@literal"),
+                                     match::any_of[match::inputs()](match::standard_shape())))
+                .bind("input")));
     }

     void apply(program& p, match::matcher_result r) const
...
@@ -273,6 +319,51 @@ struct find_triadd
     }
 };
+struct find_mul_add
+{
+    auto matcher() const
+    {
+        return match::name("gpu::add")(match::either_arg(0, 1)(
+            match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto mul_ins = r.instructions["mul"];
+        auto b_ins   = r.instructions["b"];
+        auto ins     = r.result;
+        auto args    = mul_ins->inputs();
+        assert(mul_ins != b_ins);
+
+        move_standard_front(args);
+        move_broadcasted_back(args);
+
+        args.insert(std::prev(args.end()), b_ins);
+        args.back() = ins->inputs().back();
+        p.replace_instruction(ins, hip_mul_add{}, args);
+    }
+};
+
+struct find_mul_add_relu
+{
+    auto matcher() const
+    {
+        return match::name("gpu::relu")(
+            match::arg(0)(match::name("hip::mul_add")(match::used_once()).bind("mul_add")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto mul_add_ins = r.instructions["mul_add"];
+        auto ins         = r.result;
+        auto args        = mul_add_ins->inputs();
+
+        // Use the allocation from the relu operator
+        args.back() = ins->inputs().back();
+        p.replace_instruction(ins, hip_mul_add_relu{}, args);
+    }
+};
+
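
Note: the two matchers chain: find_mul_add first rewrites a gpu::mul feeding a gpu::add into hip::mul_add, and find_mul_add_relu then collapses a gpu::relu whose input is that hip::mul_add, reusing the relu's output allocation. The fused kernel must agree with the separate ops, i.e. compute relu(a*x + b); a trivial host-side check of that equivalence:

    #include <algorithm>
    #include <cassert>

    int main()
    {
        const float x = -3.0f, a = 2.0f, b = 1.0f;
        const float mul   = a * x;                        // gpu::mul
        const float add   = mul + b;                      // gpu::add
        const float relu  = std::max(0.0f, add);          // gpu::relu
        const float fused = std::max(0.0f, a * x + b);    // hip::mul_add_relu
        assert(relu == fused);
    }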
 struct miopen_conv_bias
 {
     op::convolution op;
...
@@ -428,6 +519,8 @@ void fuse_ops::apply(program& p) const
     match::find_matches(p,
                         find_conv_bias_relu{ctx},
                         find_conv_bias{ctx},
+                        find_mul_add{},
+                        find_mul_add_relu{},
                         find_add_relu{}
                         );
     // clang-format on
...
...
@@ -11,6 +11,12 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

+void mul_add_relu(hipStream_t stream,
+                  const argument& result,
+                  const argument& arg1,
+                  const argument& arg2,
+                  const argument& arg3);
+
 void add_relu(hipStream_t stream,
               const argument& result,
               const argument& arg1,
...
+#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP
+#define MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP
+
+#include <migraphx/argument.hpp>
+#include <migraphx/config.hpp>
+#include <hip/hip_runtime_api.h>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace gpu {
+namespace device {
+
+void mul_add(hipStream_t stream,
+             const argument& result,
+             const argument& arg1,
+             const argument& arg2,
+             const argument& arg3);
+
+} // namespace device
+} // namespace gpu
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+
+#endif
...
@@ -14,7 +14,7 @@
 #include <migraphx/propagate_constant.hpp>
 #include <migraphx/eliminate_contiguous.hpp>
 #include <migraphx/common_subexpression_elimination.hpp>
-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/rewrite_pooling.hpp>
 #include <migraphx/eliminate_concat.hpp>
...
@@ -44,13 +44,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         eliminate_identity{},
         eliminate_pad{},
         dead_code_elimination{},
-        fwd_conv_batchnorm_rewrite{},
+        rewrite_batchnorm{},
         dead_code_elimination{},
         rewrite_rnn{},
         rewrite_pooling{},
         dead_code_elimination{},
-        //common_subexpression_elimination{},
-        //dead_code_elimination{},
+        // common_subexpression_elimination{},
+        // dead_code_elimination{},
         simplify_algebra{},
         dead_code_elimination{},
         auto_contiguous{},
...
...
@@ -502,6 +502,24 @@ struct test_triadd2 : verify_program<test_triadd2>
     }
 };

+struct test_mul_add : verify_program<test_mul_add>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
+        migraphx::shape bs{migraphx::shape::float_type, {3}};
+        auto x   = p.add_parameter("x", s);
+        auto a   = p.add_parameter("a", bs);
+        auto b   = p.add_parameter("b", bs);
+        auto ab  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, a);
+        auto bb  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, b);
+        auto mul = p.add_instruction(migraphx::op::mul{}, x, ab);
+        p.add_instruction(migraphx::op::add{}, mul, bb);
+        return p;
+    }
+};
+
 struct test_add_broadcast : verify_program<test_add_broadcast>
 {
     migraphx::program create_program() const
...
...
@@ -5,6 +5,8 @@
 namespace match = migraphx::match;

+MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }
+
 template <class M>
 migraphx::match::matcher_result find_match(migraphx::program& p, M&& m)
 {
...
@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
     EXPECT(bool{r.result == p.end()});
 }
+TEST_CASE(match_either_args_any1)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m =
+        match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum1});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any2)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum1});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any3)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum1});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any4)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum2});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any5)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum2});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
 TEST_CASE(match_all_of1)
 {
     migraphx::program p;
...
@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
     EXPECT(bool{r.result == sum});
 }
+TEST_CASE(match_lazy_any_of)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    p.add_instruction(pass_op{}, one);
+    auto m = match::any_of(match::any(), throws());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == one});
+}
+
+TEST_CASE(match_lazy_all_of)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    p.add_instruction(pass_op{}, one);
+    auto m = match::all_of(match::none(), throws());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == p.end()});
+}
+
+TEST_CASE(match_lazy_none_of)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    p.add_instruction(pass_op{}, one);
+    auto m = match::none_of(match::any(), throws());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == p.end()});
+}
+
 TEST_CASE(match_any_of1)
 {
     migraphx::program p;
...
@@ -396,6 +503,97 @@ TEST_CASE(match_any_of2)
     EXPECT(bool{r.result == p.end()});
 }
+TEST_CASE(match_any_of_lazy1)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::any(), match::any()).bind("x"),
+                      match::args(match::name("sum"), match::name("sum")).bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x"));
+    EXPECT(bool{r.instructions["x"] == sum});
+    EXPECT(not migraphx::contains(r.instructions, "y"));
+}
+
+TEST_CASE(match_any_of_lazy2)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
+                      match::args(match::any(), match::any()).bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x"));
+    EXPECT(bool{r.instructions["x"] == sum});
+    EXPECT(not migraphx::contains(r.instructions, "y"));
+}
+
+TEST_CASE(match_any_of_lazy3)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::any(), match::any()).bind("x"),
+                      match::args(match::name("@literal"), match::name("@literal")).bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x"));
+    EXPECT(bool{r.instructions["x"] == sum});
+    EXPECT(not migraphx::contains(r.instructions, "y"));
+}
+
+TEST_CASE(match_any_of_lazy4)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(match::any_of(
+        match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
+        match::args(match::any().bind("x2"), match::any().bind("y2"))));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x1"));
+    EXPECT(migraphx::contains(r.instructions, "y1"));
+    EXPECT(bool{r.instructions["x1"] == one});
+    EXPECT(bool{r.instructions["y1"] == two});
+    EXPECT(not migraphx::contains(r.instructions, "x2"));
+    EXPECT(not migraphx::contains(r.instructions, "y2"));
+}
+
+TEST_CASE(match_any_of_lazy5)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(match::any_of(
+        match::args(match::any().bind("x1"), match::any().bind("y1")),
+        match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x1"));
+    EXPECT(migraphx::contains(r.instructions, "y1"));
+    EXPECT(bool{r.instructions["x1"] == one});
+    EXPECT(bool{r.instructions["y1"] == two});
+    EXPECT(not migraphx::contains(r.instructions, "x2"));
+    EXPECT(not migraphx::contains(r.instructions, "y2"));
+}
+
 TEST_CASE(match_none_of1)
 {
     migraphx::program p;
...
-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/cpu/target.hpp>
 #include <migraphx/op/convolution.hpp>
...
@@ -56,7 +56,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();

-    migraphx::fwd_conv_batchnorm_rewrite opt;
+    migraphx::rewrite_batchnorm opt;
     opt.apply(p2);
     p1.compile(migraphx::cpu::target{});
     p2.compile(migraphx::cpu::target{});
...
@@ -93,10 +93,10 @@ TEST_CASE(non_literal)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();

-    migraphx::fwd_conv_batchnorm_rewrite opt;
+    migraphx::rewrite_batchnorm opt;
     opt.apply(p2);
     EXPECT(any_of(p1, &is_batch_norm));
-    EXPECT(any_of(p2, &is_batch_norm));
+    EXPECT(none_of(p2, &is_batch_norm));
 }

 TEST_CASE(as_literal)
...
@@ -121,7 +121,7 @@ TEST_CASE(as_literal)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();

-    migraphx::fwd_conv_batchnorm_rewrite opt;
+    migraphx::rewrite_batchnorm opt;
     opt.apply(p2);
     EXPECT(any_of(p1, &is_batch_norm));
     EXPECT(none_of(p2, &is_batch_norm));
...
@@ -159,7 +159,7 @@ TEST_CASE(literal_reshape)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();

-    migraphx::fwd_conv_batchnorm_rewrite opt;
+    migraphx::rewrite_batchnorm opt;
     opt.apply(p2);
     EXPECT(any_of(p1, &is_batch_norm));
     EXPECT(none_of(p2, &is_batch_norm));
...
 #include <migraphx/simplify_algebra.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/operators.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/ranges.hpp>
+#include <migraphx/instruction.hpp>
 #include <basic_ops.hpp>
 #include <test.hpp>
...
@@ -91,9 +94,9 @@ TEST_CASE(simplify_add3)
         auto x    = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
         auto one  = p2.add_literal(1);
         auto two  = p2.add_literal(2);
-        auto sum1 = p2.add_instruction(migraphx::op::add{}, one, x);
-        auto sum2 = p2.add_instruction(migraphx::op::add{}, one, two);
-        auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
+        auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
+        auto sum2 = p2.add_instruction(migraphx::op::add{}, one, sum1);
+        auto sum3 = p2.add_instruction(migraphx::op::add{}, x, sum2);
         p2.add_instruction(pass_op{}, sum3);
     }
     EXPECT(p1 == p2);
...
@@ -129,4 +132,73 @@ void simplify_add4()
     EXPECT(p1 == p2);
 }
+TEST_CASE(simplify_mul_conv1)
+{
+    migraphx::program p;
+    auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
+    auto w =
+        p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
+    auto conv = p.add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w);
+    auto a    = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
+    auto b    = p.add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
+    auto mul  = p.add_instruction(migraphx::op::mul{}, conv, b);
+    p.add_instruction(pass_op{}, mul);
+    EXPECT(conv->outputs().front()->name() == "mul");
+    p.compile(simplify_algebra_target{});
+    auto new_conv =
+        std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
+    EXPECT(new_conv->outputs().front()->name() != "mul");
+}
+
+TEST_CASE(simplify_mul_add)
+{
+    migraphx::program p1;
+    {
+        auto x   = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
+        auto one = p1.add_literal(1);
+        auto two = p1.add_literal(2);
+        auto sum = p1.add_instruction(migraphx::op::add{}, one, x);
+        auto mul = p1.add_instruction(migraphx::op::mul{}, sum, two);
+        p1.add_instruction(pass_op{}, mul);
+    }
+    p1.compile(simplify_algebra_target{});
+
+    migraphx::program p2;
+    {
+        auto x    = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
+        auto one  = p2.add_literal(1);
+        auto two  = p2.add_literal(2);
+        auto mul1 = p2.add_instruction(migraphx::op::mul{}, two, x);
+        auto mul2 = p2.add_instruction(migraphx::op::mul{}, two, one);
+        auto sum  = p2.add_instruction(migraphx::op::add{}, mul1, mul2);
+        p2.add_instruction(pass_op{}, sum);
+    }
+    EXPECT(p1 == p2);
+}
+
+TEST_CASE(simplify_inner_broadcast)
+{
+    auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
+    migraphx::program p1;
+    {
+        auto x   = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
+        auto y   = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
+        auto xb  = p1.add_instruction(b, x);
+        auto yb  = p1.add_instruction(b, y);
+        auto sum = p1.add_instruction(migraphx::op::add{}, xb, yb);
+        p1.add_instruction(pass_op{}, sum);
+    }
+    p1.compile(simplify_algebra_target{});
+
+    migraphx::program p2;
+    {
+        auto x    = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
+        auto y    = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
+        auto sum  = p2.add_instruction(migraphx::op::add{}, x, y);
+        auto sumb = p2.add_instruction(b, sum);
+        p2.add_instruction(pass_op{}, sumb);
+    }
+    EXPECT(p1 == p2);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }