"...deploy/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "835ca96b192b0dc50d31b83b78fdce462ecb1f47"
Unverified Commit 2fdf510d authored by mvermeulen, committed by GitHub

Merge branch 'develop' into onnx_autopad_fix

parents d5939189 4085af9b
......@@ -12,7 +12,7 @@ add_library(migraphx
eliminate_concat.cpp
eliminate_identity.cpp
eliminate_pad.cpp
fwd_conv_batchnorm_rewrite.cpp
rewrite_batchnorm.cpp
rewrite_rnn.cpp
rewrite_pooling.cpp
env.cpp
......
......@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
auto result = m.match(ctx, ins);
if(result != ctx.not_found())
ctx.instructions.emplace(name, ins);
ctx.instructions[name] = ins;
return result;
});
}
......@@ -240,6 +240,21 @@ void find_matches(program& p, Ms&&... ms)
}
}
template <class M>
struct find_skip
{
M m;
M matcher() const { return m; }
void apply(program&, const matcher_result&) const {}
};
template <class M>
find_skip<M> make_find_skip(M m)
{
return {m};
}
struct lazy_and
{
template <class F, class G>
......@@ -311,6 +326,12 @@ const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
const constexpr auto any_of = match_fold_f<lazy_or, false, true>{};
const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
template <class... Ms>
auto skip_matches(Ms... ms)
{
return make_find_skip(any_of(ms...));
}
inline auto inputs()
{
return [](auto ins, auto f) {
......@@ -369,6 +390,50 @@ MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref in
return ctx.not_found();
}
inline auto used_once_recursive(std::size_t depth)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
// Used once
if(start->outputs().size() == 1)
return start;
// Unused
if(start->outputs().empty())
{
if(std::next(start) == ctx.not_found())
return start;
else
return ctx.not_found();
}
// Check for dead instructions
auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
if(n == 0)
return false;
if(ins->get_shape().elements() == 0)
return false;
if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
return true;
return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return self(i, n - 1);
});
});
auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
return is_dead(i, depth);
});
if(dead + 1 == start->outputs().size())
return start;
return ctx.not_found();
});
}
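My reading of the matcher above (an interpretation, not text from the diff): used_once_recursive(depth) matches an instruction that is used exactly once, is the trailing unused instruction, or whose outputs are all dead up to the given depth except one — i.e. it will effectively be used once after dead code elimination runs.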
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
return ins;
return ctx.not_found();
}
template <class... Ms>
auto skip_output(Ms... ms)
{
......@@ -404,6 +469,12 @@ inline auto name(std::unordered_set<std::string> names)
});
}
template <class... Ts>
inline auto name(std::string s, Ts... xs) // NOLINT
{
return name(std::unordered_set<std::string>{std::move(s), std::move(xs)...});
}
inline auto nargs(std::size_t n)
{
return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
......
......@@ -13,9 +13,9 @@ struct program;
/**
* Rewrite batchnorm to a multiply and add.
*/
struct fwd_conv_batchnorm_rewrite
struct rewrite_batchnorm
{
std::string name() const { return "fwd_conv_batchnorm_rewrite"; }
std::string name() const { return "rewrite_batchnorm"; }
void apply(program& p) const;
};
......
......@@ -206,6 +206,16 @@ struct onnx_parser
return out_lens;
}
instruction_ref make_contiguous(instruction_ref ins)
{
if(ins->get_shape().standard())
{
return ins;
}
return prog.add_instruction(op::contiguous{}, ins);
}
template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{
......@@ -441,12 +451,7 @@ struct onnx_parser
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
if(!args[0]->get_shape().standard())
{
args[0] = prog.add_instruction(op::contiguous{}, args[0]);
}
return prog.add_instruction(op, args[0]);
return prog.add_instruction(op, make_contiguous(args[0]));
}
instruction_ref
......@@ -494,23 +499,41 @@ struct onnx_parser
{
axis = parse_value(attributes.at("axis")).at<int>();
}
op::gather op{axis};
return prog.add_instruction(op, std::move(args));
return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
}
instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::slice op;
std::vector<size_t> dims = args[0]->get_shape().lens();
size_t num_dims = dims.size();
if(contains(attributes, "axes"))
{
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
}
else
{
op.axes = std::vector<int64_t>(num_dims);
std::iota(op.axes.begin(), op.axes.end(), 0);
}
if(contains(attributes, "ends"))
{
literal s = parse_value(attributes.at("ends"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
for(size_t i = 0; i < num_dims; i++)
{
if(static_cast<size_t>(op.ends[i]) > dims[i])
{
op.ends[i] = dims[i];
}
}
}
if(contains(attributes, "starts"))
{
literal s = parse_value(attributes.at("starts"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
......
......@@ -74,7 +74,8 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
// if the input is a convert operator, use its input
// as the current input
instruction_ref input_fp16{};
if(input->name() == "convert")
if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == shape::half_type)
{
input_fp16 = input->inputs().front();
}
......
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
......@@ -11,7 +12,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void fwd_conv_batchnorm_rewrite::apply(program& p) const
void rewrite_batchnorm::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
......@@ -25,46 +26,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
continue;
auto conv_ins = ins->inputs()[0];
if(conv_ins->name() != "convolution")
continue;
// Get convolution weights
auto weights = conv_ins->inputs()[1]->eval();
if(weights.empty())
continue;
auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}};
// Get epsilon
auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon;
// Get convolution op
auto conv_op = conv_ins->get_operator();
auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()};
argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
[&](auto weights2,
auto gamma2,
auto bias2,
auto mean2,
auto variance2,
auto new_weights2,
auto new_bias2) {
dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
[&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
new_weights2(k, c, h, w) =
gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w);
});
dfor(new_bias.get_shape().elements())([&](std::size_t c) {
new_bias2[c] =
bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
argument a{s};
argument b{s};
visit_all(gamma, bias, mean, variance, a, b)(
[&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
dfor(a.get_shape().elements())(
[&](std::size_t c) { a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon); });
dfor(b.get_shape().elements())([&](std::size_t c) {
b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
});
});
// Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape().lens()}, l_bias);
p.replace_instruction(ins, op::add{}, {c, b});
auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = p.add_literal({a.get_shape(), a.data()});
auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
p.replace_instruction(ins, add);
}
}
......
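For reference, the rewrite above folds the inference-time batchnorm formula into a per-channel multiply and add; a rough sketch of the algebra, using the same names as the code:
// y = gamma * (x - mean) / sqrt(variance + epsilon) + bias
//   = a * x + b, where for each channel c:
// a[c] = gamma[c] / std::sqrt(variance[c] + epsilon);
// b[c] = bias[c] - gamma[c] * mean[c] / std::sqrt(variance[c] + epsilon);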
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct find_add_lit_broadcast
auto lit_broadcast() { return match::any_of(match::is_constant(), match::name("broadcast")); }
auto not_lit_broadcast() { return match::none_of(match::is_constant(), match::name("broadcast")); }
auto op_lit_broadcast(std::string op, std::string x, std::string y)
{
return match::name(std::move(op))(match::either_arg(0, 1)(
lit_broadcast().bind(std::move(x)), not_lit_broadcast().bind(std::move(y))));
}
auto conv_const_weights()
{
return match::name("convolution")(match::used_once(),
match::args(match::any(), match::is_constant().bind("w")));
}
struct find_mul_conv
{
auto lit_broadcast() const
auto matcher() const
{
return match::any_of(match::name("@literal"), match::name("broadcast"));
return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"),
match::name("broadcast").bind("a")));
}
auto not_lit_broadcast() const
void apply(program& p, match::matcher_result r) const
{
return match::none_of(match::name("@literal"), match::name("broadcast"));
auto ins = r.result;
auto conv_ins = r.instructions["conv"];
auto a_ins = r.instructions["a"];
auto w_ins = r.instructions["w"];
auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
if(broadcast_op.axis != 1)
return;
auto new_a = p.insert_instruction(
ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
auto new_conv = p.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
p.replace_instruction(ins, new_conv);
}
auto add_lit_broadcast(std::string x, std::string y) const
};
// a * (x + b) => a * x + a * b
struct find_mul_add
{
auto matcher() const
{
return match::name("add")(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
not_lit_broadcast().bind(std::move(y))));
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins);
auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins);
p.replace_instruction(ins, op::add{}, ax_ins, ab_ins);
}
};
struct find_add_lit_broadcast
{
auto matcher() const
{
return match::name("add")(
match::args(add_lit_broadcast("a", "x"), add_lit_broadcast("b", "y")));
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
p.replace_instruction(ins, op::add{}, x_ins, sumab);
}
};
struct find_double_add_lit_broadcast
{
auto matcher() const
{
return match::name("add")(
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
}
void apply(program& p, match::matcher_result r) const
......@@ -36,11 +117,9 @@ struct find_add_lit_broadcast
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
if(a_ins->name() != b_ins->name())
return;
instruction_ref sumab;
if(a_ins->name() == "broadcast")
if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
{
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return;
......@@ -59,7 +138,46 @@ struct find_add_lit_broadcast
}
};
void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
struct find_inner_broadcast
{
auto matcher() const
{
return match::name("mul", "add")(
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
if(xbroadcast.axis != ybroadcast.axis)
return;
auto op = p.insert_instruction(
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
p.replace_instruction(ins, xbroadcast, op);
}
};
void simplify_algebra::apply(program& p) const
{
// Run simplifications multiple times
for(int i = 0; i < 4; i++)
{
match::find_matches(p,
find_inner_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
find_mul_conv{},
find_mul_add{});
dead_code_elimination{}.apply(p);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
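As a quick orientation (my summary, not part of the diff), the matchers above perform roughly these rewrites, where a and b denote constants or broadcast literals:
// find_mul_conv:          broadcast(a) * conv(x, w)    ->  conv(x, a * w)   (a re-broadcast over w's output channels)
// find_mul_add:           a * (x + b)                  ->  a * x + a * b
// find_add_lit_broadcast: (x + a) + b                  ->  x + (a + b)
// find_inner_broadcast:   broadcast(x) op broadcast(y) ->  broadcast(x op y)
Running the loop four times with dead_code_elimination in between lets these rewrites feed into each other.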
......@@ -16,6 +16,7 @@ add_library(migraphx_device
device/argmin.cpp
device/max.cpp
device/min.cpp
device/mul_add.cpp
device/exp.cpp
device/erf.cpp
device/log.cpp
......
......@@ -6,6 +6,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul_add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
nary(stream, result, arg1, arg2, arg3)(
[](auto x, auto a, auto b) { return std::max<decltype(a * x + b)>(0, a * x + b); });
}
void add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
......
......@@ -118,6 +118,111 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
});
}
template <class F, class... Arguments>
void nary_double_broadcast_vec_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{
assert(barg1.get_shape().broadcasted());
assert(barg2.get_shape().broadcasted());
assert(barg1.get_shape() == barg2.get_shape());
const auto& output_shape = result.get_shape();
const auto& b_shape = barg1.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) {
using type = typename decltype(output)::value_type;
const std::size_t nelements = output.size() / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
buffer[i] = binput1.data()[i];
}
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
buffer[i + bdim_vec_len] = binput2.data()[i];
}
__syncthreads();
auto* bp = as_pointer(buffer);
// Process the data
for(size_t i = idx.global; i < nelements; i += nglobal)
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b1 = bp[bidx];
auto b2 = bp[bidx + bdim_len];
auto out = output.data()[i];
for(std::size_t j = 0; j < vec_size; j++)
{
out[j] = f(inputs.data()[i][j]..., b2, b1);
}
output.data()[i] = out;
}
});
});
}
template <class F, class... Arguments>
void nary_double_broadcast_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{
assert(barg1.get_shape().broadcasted());
assert(barg2.get_shape().broadcasted());
assert(barg1.get_shape() == barg2.get_shape());
const auto& output_shape = result.get_shape();
const auto& b_shape = barg1.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
std::size_t nelements = result.get_shape().elements();
hip_visit_all(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) {
using type = typename decltype(output)::value_type;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i] = binput1.data()[i];
}
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i + bdim_len] = binput2.data()[i];
}
__syncthreads();
// Process the data
for(size_t i = idx.global; i < nelements; i += nglobal)
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b1 = buffer[bidx];
auto b2 = buffer[bidx + bdim_len];
output.data()[i] = f(inputs.data()[i]..., b2, b1);
}
});
});
}
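A worked example of the broadcast index math above (an assumption on my part, taking a contiguous NCHW output with the bias broadcast along C): bdim = 1, bdim_len = C, bdim_stride = H*W and bdim_next_stride = C*H*W, so
// bidx = (i % (C * H * W)) / (H * W);   // recovers the channel of linear element i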
template <class F, class... Arguments>
void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
......@@ -177,49 +282,113 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
}
template <class... Arguments>
auto nary(hipStream_t stream, argument result)
bool broadcastable(bool& divisible_by_4,
std::size_t max_size,
const argument& result,
const argument& barg,
const Arguments&... args)
{
divisible_by_4 = false;
auto bshape = barg.get_shape();
const bool standard =
all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
const bool same_shapes =
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
// TODO: Check result and args shape is the same
if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
{
auto not_zero = [](auto x) { return x != 0; };
const auto& strides = bshape.strides();
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
auto b_idx = std::distance(strides.begin(), b_it);
auto b_len = result.get_shape().lens()[b_idx];
auto b_stride = result.get_shape().strides()[b_idx];
assert(bshape.lens()[b_idx] == b_len);
if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero))
{
divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(front_args(args...).get_shape().elements() % 4 == 0);
return true;
}
}
return false;
}
inline bool broadcastable(bool& divisible_by_4, std::size_t, const argument&, const argument&)
{
divisible_by_4 = false;
return false;
}
// Nullary
inline auto nary(hipStream_t stream, argument result)
{
return [=](auto f) { nary_standard_impl(stream, f, result); };
}
// Unary
inline auto nary(hipStream_t stream, argument result, argument arg)
{
return [=](auto f) { nary_impl(stream, f, result, arg); };
}
// Binary
inline auto nary(hipStream_t stream, argument result, argument arg, argument barg)
{
return [=](auto f) {
bool divisible_by_4 = false;
if(broadcastable(divisible_by_4, 2048, result, barg, arg))
{
if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg, arg);
else
nary_broadcast_impl(stream, f, result, barg, arg);
}
else
{
nary_impl(stream, f, result, arg, barg);
}
};
}
template <class... Arguments>
auto nary(hipStream_t stream, argument result, Arguments... args)
{
static_assert(sizeof...(args) > 2, "Args needs to be greater than 2");
return [=](auto f) {
auto barg = back_args(args...);
bool fallback = pop_back_args(args...)([&](auto&&... args2) {
auto bshape = barg.get_shape();
const bool standard =
all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
const bool same_shapes = all_of(
{args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
// TODO: Check result and args shape is the same
if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
auto barg1 = back_args(args...);
bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
auto barg2 = back_args(args2...);
bool fallback2 =
barg2.get_shape() != barg1.get_shape() or not barg2.get_shape().broadcasted() or
pop_back_args(args2...)([&](auto&&... args3) {
bool divisible_by_4 = false;
if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
{
if(divisible_by_4)
nary_double_broadcast_vec_impl(
stream, f, result, barg1, barg2, args3...);
else
nary_double_broadcast_impl(stream, f, result, barg1, barg2, args3...);
return false;
}
return true;
});
if(not fallback2)
return false;
bool divisible_by_4 = false;
if(broadcastable(divisible_by_4, 2048, result, barg1, args2...))
{
auto not_zero = [](auto x) { return x != 0; };
const auto& strides = bshape.strides();
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
auto b_idx = std::distance(strides.begin(), b_it);
auto b_len = result.get_shape().lens()[b_idx];
auto b_stride = result.get_shape().strides()[b_idx];
assert(bshape.lens()[b_idx] == b_len);
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{
const bool divisible_by_4 =
(b_len % 4 == 0) and (b_stride % 4 == 0) and
(front_args(args...).get_shape().elements() % 4 == 0);
if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg, args2...);
else
nary_broadcast_impl(stream, f, result, barg, args2...);
return false;
}
if(divisible_by_4)
nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
else
nary_broadcast_impl(stream, f, result, barg1, args2...);
return false;
}
return true;
});
if(fallback)
if(fallback1)
nary_impl(stream, f, result, args...);
};
}
......
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul_add(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b) { return a * x + b; });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -2,6 +2,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device/mul_add.hpp>
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/instruction.hpp>
......@@ -198,21 +199,62 @@ struct hip_add_relu
}
};
struct hip_mul_add
{
std::string name() const { return "hip::mul_add"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_mul_add_relu
{
std::string name() const { return "hip::mul_add_relu"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::mul_add_relu(
ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
// Ensure the last argument is the broadcasted one
auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
if(it != args.end())
std::swap(*it, *std::prev(args.end(), 2));
auto last = std::prev(args.end());
auto it =
std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
if(it != last)
std::swap(*it, *std::prev(last));
}
void move_standard_front(std::vector<instruction_ref>& args)
{
// Ensure the first argument is the standard one
auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
if(it != args.end())
auto last = std::prev(args.end());
auto it =
std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
if(it != last)
std::swap(*it, args.front());
}
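The change to both helpers above (my reading of the diff): the last element of args is the output allocation, so the search now stops at std::prev(args.end()) rather than scanning the whole vector, which keeps the allocation from being swapped into an input slot.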
......@@ -220,11 +262,13 @@ struct find_add_relu
{
auto matcher() const
{
return match::name("gpu::relu")(
match::arg(0)(match::any_of(match::name("gpu::add"),
match::name("hip::triadd"),
match::any_of[match::inputs()](match::standard_shape()))
.bind("add")));
return match::name("gpu::relu")(match::arg(0)(
match::used_once(),
match::any_of(match::name("gpu::add"),
match::name("hip::triadd"),
match::any_of(match::name("@literal"),
match::any_of[match::inputs()](match::standard_shape())))
.bind("add")));
}
void apply(program& p, match::matcher_result r) const
......@@ -249,8 +293,10 @@ struct find_triadd
auto matcher() const
{
return match::name("gpu::add")(match::either_arg(0, 1)(
match::name("gpu::add").bind("add"),
match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
match::name("gpu::add")(match::used_once()).bind("add"),
match::any(match::any_of(match::name("@literal"),
match::any_of[match::inputs()](match::standard_shape())))
.bind("input")));
}
void apply(program& p, match::matcher_result r) const
......@@ -273,6 +319,51 @@ struct find_triadd
}
};
struct find_mul_add
{
auto matcher() const
{
return match::name("gpu::add")(match::either_arg(0, 1)(
match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
}
void apply(program& p, match::matcher_result r) const
{
auto mul_ins = r.instructions["mul"];
auto b_ins = r.instructions["b"];
auto ins = r.result;
auto args = mul_ins->inputs();
assert(mul_ins != b_ins);
move_standard_front(args);
move_broadcasted_back(args);
args.insert(std::prev(args.end()), b_ins);
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_mul_add{}, args);
}
};
struct find_mul_add_relu
{
auto matcher() const
{
return match::name("gpu::relu")(
match::arg(0)(match::name("hip::mul_add")(match::used_once()).bind("mul_add")));
}
void apply(program& p, match::matcher_result r) const
{
auto mul_add_ins = r.instructions["mul_add"];
auto ins = r.result;
auto args = mul_add_ins->inputs();
// Use the allocation from the relu operator
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_mul_add_relu{}, args);
}
};
struct miopen_conv_bias
{
op::convolution op;
......@@ -428,6 +519,8 @@ void fuse_ops::apply(program& p) const
match::find_matches(p,
find_conv_bias_relu{ctx},
find_conv_bias{ctx},
find_mul_add{},
find_mul_add_relu{},
find_add_relu{}
);
// clang-format on
......
......@@ -11,6 +11,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul_add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
void add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul_add(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -8,51 +8,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
template <class... Ts>
rocblas_status generic_rocblas_gemm_ex(Ts&&... xs)
{
return rocblas_gemm_ex(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_batched_gemm_ex(Ts&&... xs)
{
return rocblas_gemm_strided_batched_ex(std::forward<Ts>(xs)...);
}
template <class T>
struct compute_rocblas_type
{
using type = T;
};
template <class T>
struct compute_rocblas_type<const T>
{
using type = const typename compute_rocblas_type<T>::type;
};
template <>
struct compute_rocblas_type<half>
{
using type = rocblas_half;
};
template <class T>
using rb_type = typename compute_rocblas_type<T>::type;
template <class T>
rb_type<T> to_rocblas_type(T x)
{
return reinterpret_cast<const rb_type<T>&>(x);
}
template <class T>
rb_type<T>* to_rocblas_type(T* x)
{
return reinterpret_cast<rb_type<T>*>(x);
}
shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> in_shapes(inputs);
......@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx,
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
auto alpha_r = as(op.alpha);
auto beta_r = as(beta);
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
assert(k % 4 == 0);
auto num_matrices = std::accumulate(
......@@ -119,36 +74,36 @@ argument rocblas_quant_gemm::compute(context& ctx,
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
rocblas_datatype_i8_r,
ldb,
to_pointer(args.at(0)),
rocblas_datatype_i8_r,
lda,
&beta_r,
to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
rocblas_datatype_i32_r,
rocblas_gemm_algo_standard,
0,
0,
nullptr,
nullptr);
rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
rocblas_datatype_i8_r,
ldb,
to_pointer(args.at(0)),
rocblas_datatype_i8_r,
lda,
&beta_r,
to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
rocblas_datatype_i32_r,
rocblas_gemm_algo_standard,
0,
0,
nullptr,
nullptr);
}
else
{
generic_rocblas_batched_gemm_ex(
rocblas_gemm_strided_batched_ex(
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
......
......@@ -14,7 +14,7 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/eliminate_concat.hpp>
......@@ -44,13 +44,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
eliminate_identity{},
eliminate_pad{},
dead_code_elimination{},
fwd_conv_batchnorm_rewrite{},
rewrite_batchnorm{},
dead_code_elimination{},
rewrite_rnn{},
rewrite_pooling{},
dead_code_elimination{},
//common_subexpression_elimination{},
//dead_code_elimination{},
// common_subexpression_elimination{},
// dead_code_elimination{},
simplify_algebra{},
dead_code_elimination{},
auto_contiguous{},
......
......@@ -26,7 +26,6 @@ struct tf_parser
{
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
using node_map = std::map<std::string, tensorflow::NodeDef>;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
......@@ -149,9 +148,26 @@ struct tf_parser
return axes;
}
std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
{
uint32_t bitwise_compare = 1;
std::vector<int64_t> axes;
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes the mask applies to
if(((mask >> i) & bitwise_compare) == 1)
axes.push_back(1);
else
axes.push_back(0);
}
return axes;
}
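For instance (a hypothetical call):
// get_axes_from_mask(4, 0b0101) == {1, 0, 1, 0}   // the mask selects axes 0 and 2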
tf_parser()
{
add_generic_op("All", op::identity{});
add_generic_op("Identity", op::identity{});
add_generic_op("LessEqual", op::identity{});
add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Rsqrt", op::rsqrt{});
......@@ -166,6 +182,7 @@ struct tf_parser
add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("Cast", &tf_parser::parse_cast, false);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
......@@ -177,14 +194,15 @@ struct tf_parser
add_mem_op("GatherV2", &tf_parser::parse_gather, false);
add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Mean", &tf_parser::parse_mean, false);
add_mem_op("OneHot", &tf_parser::parse_onehot, false);
add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false);
add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>);
add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
add_mem_op("Transpose", &tf_parser::parse_transpose, false);
}
......@@ -547,7 +565,7 @@ struct tf_parser
}
if(contains(attributes, "transpose_b"))
{
transb = attributes.at("transpose_a").b();
transb = attributes.at("transpose_b").b();
}
if(contains(attributes, "adj_x"))
......@@ -574,8 +592,7 @@ struct tf_parser
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
bool keep_dims = attributes.at("keep_dims").b();
auto lens = args[0]->get_shape().lens();
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector<int64_t>(), lens.size());
auto axes = args[1]->eval().get<int32_t>().to_vector<int64_t>();
if(keep_dims)
{
......@@ -588,6 +605,32 @@ struct tf_parser
}
}
instruction_ref
parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
int64_t axis = -1;
float on_value = args[2]->eval().at<float>();
float off_value = args[3]->eval().at<float>();
std::vector<float> depth_input(depth * depth, off_value);
for(int i = 0; i < depth; i++)
{
depth_input[depth * i + i] = on_value;
}
if(contains(attributes, "axis"))
axis = attributes.at("axis").i();
if(axis == -1)
{
shape s{shape::float_type, {depth, depth}};
auto l0 = prog.add_literal({s, depth_input});
return prog.add_instruction(op::gather{0}, {l0, args[0]});
}
MIGRAPHX_THROW("MIGraphX does not support axis != -1");
}
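A rough worked example of the construction above (values are hypothetical): with depth = 3, on_value = 1.0f and off_value = 0.0f,
// depth_input = {1,0,0, 0,1,0, 0,0,1}               // a 3x3 identity matrix
// gather(axis = 0, indices = {2, 0})  =>  {{0,0,1}, {1,0,0}}   // one-hot rows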
instruction_ref parse_pack(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
......@@ -799,21 +842,50 @@ struct tf_parser
std::vector<instruction_ref> args)
{
op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector();
size_t num_axes = args[0]->get_shape().lens().size();
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector();
auto l0 = args[0];
size_t num_axes = l0->get_shape().lens().size();
std::vector<size_t> axes = l0->get_shape().lens();
op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(ends.begin(), ends.end());
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
uint32_t begin_mask = 0;
uint32_t end_mask = 0;
uint32_t shrink_axis_mask = 0;
uint32_t bitwise_compare = 1;
std::vector<int64_t> squeeze_axes;
if(contains(attributes, "begin_mask"))
begin_mask = static_cast<uint32_t>(attributes.at("begin_mask").i());
if(contains(attributes, "end_mask"))
end_mask = static_cast<uint32_t>(attributes.at("end_mask").i());
if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask);
for(size_t i = 0; i < num_axes; i++)
{
if(begin_axes.at(i) == 1)
{
op.starts.at(i) = 0;
}
if(end_axes.at(i) == 1)
{
op.ends.at(i) = axes.at(i);
}
}
auto l1 = prog.add_instruction(op, l0);
if(shrink_axis_mask == 0)
return l1;
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to squeeze
......@@ -821,8 +893,7 @@ struct tf_parser
squeeze_axes.push_back(i);
}
auto l0 = prog.add_instruction(op, make_contiguous(args[0]));
return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
return prog.add_instruction(op::squeeze{squeeze_axes}, l1);
}
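A rough worked example of the mask handling (hypothetical values): for an input with lens {4, 5, 6}, starts {1, 1, 1}, ends {2, 2, 2}, begin_mask = 0b001 and end_mask = 0b100, axis 0's start is reset to 0 and axis 2's end is reset to 6, giving an effective slice of starts {0, 1, 1} and ends {2, 2, 6}.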
instruction_ref
......@@ -862,10 +933,16 @@ struct tf_parser
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
// assert ops ignored
if(node.op() == "Assert" or contains(name, "Assert"))
return;
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
// control dependencies (signified by ^ before the name) are ignored
if(contains(input, "^"))
continue;
if(nodes.count(input) > 0)
{
auto&& iname = get_name(nodes.at(input));
......
......@@ -502,6 +502,24 @@ struct test_triadd2 : verify_program<test_triadd2>
}
};
struct test_mul_add : verify_program<test_mul_add>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape bs{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto a = p.add_parameter("a", bs);
auto b = p.add_parameter("b", bs);
auto ab = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, a);
auto bb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, b);
auto mul = p.add_instruction(migraphx::op::mul{}, x, ab);
p.add_instruction(migraphx::op::add{}, mul, bb);
return p;
}
};
struct test_add_broadcast : verify_program<test_add_broadcast>
{
migraphx::program create_program() const
......
......@@ -5,6 +5,8 @@
namespace match = migraphx::match;
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }
template <class M>
migraphx::match::matcher_result find_match(migraphx::program& p, M&& m)
{
......@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_either_args_any1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_all_of1)
{
migraphx::program p;
......@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_lazy_any_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::any_of(match::any(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == one});
}
TEST_CASE(match_lazy_all_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::all_of(match::none(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_lazy_none_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::none_of(match::any(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_any_of1)
{
migraphx::program p;
......@@ -396,6 +503,97 @@ TEST_CASE(match_any_of2)
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_any_of_lazy1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("sum"), match::name("sum")).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
match::args(match::any(), match::any()).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("@literal"), match::name("@literal")).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of(
match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
match::args(match::any().bind("x2"), match::any().bind("y2"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
EXPECT(bool{r.instructions["x1"] == one});
EXPECT(bool{r.instructions["y1"] == two});
EXPECT(not migraphx::contains(r.instructions, "x2"));
EXPECT(not migraphx::contains(r.instructions, "y2"));
}
TEST_CASE(match_any_of_lazy5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of(
match::args(match::any().bind("x1"), match::any().bind("y1")),
match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
EXPECT(bool{r.instructions["x1"] == one});
EXPECT(bool{r.instructions["y1"] == two});
EXPECT(not migraphx::contains(r.instructions, "x2"));
EXPECT(not migraphx::contains(r.instructions, "y2"));
}
TEST_CASE(match_none_of1)
{
migraphx::program p;
......
......@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
......@@ -514,6 +515,32 @@ TEST_CASE(shape_gather_test)
EXPECT(p == prog);
}
TEST_CASE(transpose_gather_test)
{
migraphx::program p;
auto make_contiguous = [&p](migraphx::instruction_ref ins) {
if(ins->get_shape().standard())
{
return ins;
}
return p.add_instruction(migraphx::op::contiguous{}, ins);
};
auto data = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}});
auto ind =
p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
auto tr_data = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, data);
auto tr_ind = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, ind);
int axis = 1;
p.add_instruction(
migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind));
auto prog = migraphx::parse_onnx("transpose_gather.onnx");
EXPECT(p == prog);
}
TEST_CASE(flatten_test)
{
migraphx::program p;
......