"...resnet50_tensorflow.git" did not exist on "4577d2c91f34dae1b6133188b51dba424b92ddb5"
Unverified Commit 8752875a authored by Paul Fultz II, committed by GitHub

Improvements to handling an add with a constant passed to the dot operator (#1280)

This rewrites dot operators like X(Y + b) to XY + Xb when b is constant, since the add can then be folded away. It also improves the handling of pointwise ops with broadcasted operands, which helps constant propagation. In addition:

- Improve gemm fusion with a mul_add
- Improve support for broadcast shapes in gemm
parent af7f22d8
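In IR terms, the dot rewrite looks like this (a minimal sketch; the shapes and literal values are illustrative, mirroring the `simplify_dot_add` test added below):

```cpp
#include <migraphx/module.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>

// Before: dot(x, add(y, b)) with constant x and b -- the add blocks folding.
// After:  add(dot(x, y), dot(x, b)) -- dot(x, b) is all-constant, so
// propagate_constant can collapse it into a single literal.
void dot_add_sketch()
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2}};
    migraphx::module m;
    auto y   = m.add_parameter("y", s);
    auto x   = m.add_literal(migraphx::literal{s, {1, 2, 3, 4}}); // constant
    auto b   = m.add_literal(migraphx::literal{s, {5, 6, 7, 8}}); // constant
    auto sum = m.add_instruction(migraphx::make_op("add"), y, b);
    m.add_instruction(migraphx::make_op("dot"), x, sum);
    // simplify_algebra rewrites the dot to:
    //   dot1 = dot(x, y); dot2 = dot(x, b); add(dot1, dot2)
}
```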
@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
     return nullopt;
 }
 
+MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
+{
+    return contains({"broadcast", "multibroadcast"}, ins->name());
+}
+
 template <class... Ms>
 auto skip(Ms... ms)
 {
@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
 template <class... Ms>
 auto pointwise(Ms... ms)
 {
-    return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
-                                             ms...);
+    return match::has_attribute("pointwise")(ms...);
 }
 
 } // namespace match

---
@@ -208,6 +208,42 @@ struct find_mul_add
     }
 };
 
+struct find_dot_add
+{
+    auto matcher() const
+    {
+        return match::name("dot")(match::either_arg(0, 1)(
+            match::name("add")(
+                match::either_arg(0, 1)(match::any().bind("x"),
+                                        match::any_of(match::is_constant()).bind("b")),
+                match::none_of(match::args(match::is_constant(), match::is_constant())),
+                match::used_once()),
+            match::is_constant().bind("a")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins   = r.result;
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+        auto x_ins = r.instructions["x"];
+        assert(x_ins != b_ins);
+
+        const bool flipped = a_ins == ins->inputs().back();
+
+        auto insert_dot = [&](auto x, auto y) {
+            if(flipped)
+                return m.insert_instruction(ins, make_op("dot"), y, x);
+            else
+                return m.insert_instruction(ins, make_op("dot"), x, y);
+        };
+        auto ax_ins = insert_dot(a_ins, x_ins);
+        auto ab_ins = insert_dot(a_ins, b_ins);
+        m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
+    }
+};
+
 struct find_add_lit_broadcast
 {
     auto matcher() const
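Note that because matrix multiplication is not commutative, the `flipped` flag records which side of the `dot` the constant `a` sat on, and `insert_dot` preserves that orientation for both new products: `A(X + b) = AX + Ab`, whereas `(X + b)A = XA + bA`.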
@@ -267,28 +303,26 @@ struct find_double_add_lit_broadcast
 struct find_inner_broadcast
 {
-    auto matcher() const
-    {
-        return pointwise(
-            match::nargs(2),
-            match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
-    }
+    auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
 
     void apply(module& m, const match::matcher_result& r) const
     {
-        auto ins   = r.result;
-        auto x_ins = r.instructions["x"];
-        auto y_ins = r.instructions["y"];
-
-        auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
-        auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
-
-        if(xbroadcast.axis != ybroadcast.axis)
+        auto ins        = r.result;
+        auto broadcasts = ins->inputs();
+        if(broadcasts.empty())
             return;
+        std::vector<instruction_ref> inputs;
+        std::transform(broadcasts.begin(),
+                       broadcasts.end(),
+                       std::back_inserter(inputs),
+                       [](auto i) { return i->inputs().front(); });
+        if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
+               return i->get_shape() != inputs.front()->get_shape();
+           }))
+            return;
 
-        auto op = m.insert_instruction(
-            ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
-        m.replace_instruction(ins, xbroadcast, op);
+        auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
+        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
    }
 };
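The matcher now accepts any number of `broadcast`/`multibroadcast` inputs instead of exactly two `broadcast`s with matching axes; the only guard is that every pre-broadcast input has the same shape. A minimal sketch of the rewrite (shapes borrowed from the new `simplify_inner_broadcast2` test below):

```cpp
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/multibroadcast.hpp>

// add(multibroadcast(x), multibroadcast(y)) ==> multibroadcast(add(x, y)):
// the pointwise op now runs on the small pre-broadcast shape.
void inner_broadcast_sketch()
{
    migraphx::shape s{migraphx::shape::int32_type, {1, 1, 1, 1}};
    migraphx::op::multibroadcast b{{2, 1, 4, 5}};
    migraphx::module m;
    auto x  = m.add_parameter("x", s);
    auto y  = m.add_parameter("y", s);
    auto xb = m.add_instruction(b, x); // before: add computes 2*4*5 elements
    auto yb = m.add_instruction(b, y);
    m.add_instruction(migraphx::make_op("add"), xb, yb);
    // after simplify_algebra: add computes 1 element, then one broadcast
}
```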
@@ -416,8 +450,9 @@ struct find_splits
 {
     auto matcher() const
     {
-        return match::any(match::any_of[match::outputs()](match::name("slice")(
-            match::any_of[match::outputs()](match::pointwise(), reduction()))));
+        return match::any(
+            match::any_of[match::outputs()](match::name("slice")(match::any_of[match::outputs()](
+                match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction()))));
     }
 
     static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
@@ -1048,6 +1083,7 @@ void simplify_algebra::apply(module& m) const
                         find_mul_conv{},
                         find_mul_slice_conv{},
                         find_mul_add{},
+                        find_dot_add{},
                         find_div_const{},
                         find_sub_const{},
                         find_rsqrt{},

---
...@@ -48,8 +48,10 @@ ...@@ -48,8 +48,10 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp> #include <migraphx/op/clip.hpp>
#include <migraphx/op/contiguous.hpp>
#include <cmath> #include <cmath>
#include <set> #include <set>
@@ -279,6 +281,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
 struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
 {
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4).standard();
+        return inputs[0];
+    }
     // Empty finalize to skip dimension reduction
     void finalize(context&, const shape&, const std::vector<shape>&) {}
 };
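A hedged reading of the new `compute_shape`: the op receives four shapes (the three addends plus the output allocation that GPU ops carry as their last argument), requires them all to be standard, and reports the first input's shape as the result.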
@@ -943,28 +950,70 @@ struct find_gemm_add
     }
 };
 
-auto pointwise_name(const std::string& s)
-{
-    return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
-        module_ref pm = ins->module_inputs().front();
-        auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
-        if(n != 1)
-            return false;
-        return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
-            return starts_with(i.name(), "@") or i.name() == s;
-        });
-    }));
-}
-
 struct find_gemm_pointwise
 {
     auto matcher() const
     {
-        return pointwise_name("add")(
+        return precompile_name("pointwise")(
             match::nargs(3),
-            match::all_of[match::inputs()](match::standard_shape()),
-            match::either_arg(0, 1)(match::used_once().bind("c"),
-                                    match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
+            match::either_arg(0, 1)(
+                match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
+                match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
     }
 
+    // TODO: Move to matcher.hpp
+    static auto match_param(const std::string& name)
+    {
+        return match::make_basic_pred_matcher([=](auto ins) {
+            if(ins->name() != "@param")
+                return false;
+            auto p = any_cast<builtin::param>(ins->get_operator());
+            return p.parameter == name;
+        });
+    }
+
+    template <class M>
+    static auto match_mul_const(M m, const std::string& var)
+    {
+        return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
+            .bind(var + "_mul");
+    }
+
+    static auto match_add(const std::string& input, const std::string& output)
+    {
+        auto param     = match::name("@param");
+        auto add       = match::name("add")(match::args(param, param));
+        auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
+                                       match_mul_const(match_param(output), "beta"));
+        auto mul_add   = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
+        auto add_mul   = match_mul_const(add, "gamma");
+        return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
+    }
+
+    static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
+
+    template <class Gemm>
+    static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
+    {
+        auto names = pm->get_parameter_names();
+        if(names.size() != 2)
+            return false;
+        std::sort(names.begin(), names.end());
+        unsigned output = input == 0 ? 1 : 0;
+        auto mr         = match::match_instruction(
+            *pm, std::prev(pm->end()), match_add(names[input], names[output]));
+        if(mr.result == pm->end())
+            return false;
+        if(contains(mr.instructions, "alpha_mul"))
+            gemm.alpha *= get_float(mr.instructions["alpha"]);
+        else if(contains(mr.instructions, "beta_mul"))
+            gemm.beta *= get_float(mr.instructions["beta"]);
+        else if(contains(mr.instructions, "gamma_mul"))
+        {
+            gemm.alpha *= get_float(mr.instructions["gamma"]);
+            gemm.beta *= get_float(mr.instructions["gamma"]);
+        }
+        return true;
+    }
+
     void apply(module& m, const match::matcher_result& r) const
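`update_gemm` walks the fused pointwise module and folds literal multipliers into the GEMM's scale factors. Writing P for the parameter fed by the gemm and C for the other operand, the bodies accepted by `match_add` map onto alpha/beta as follows (apply() first sets beta to 1):

```
add(P, C)          ->  (plain add, no extra scaling)
add(mul(a, P), C)  ->  alpha *= a              ("alpha_mul")
add(P, mul(b, C))  ->  beta  *= b              ("beta_mul")
mul(g, add(P, C))  ->  alpha *= g; beta *= g   ("gamma_mul")
```

These correspond directly to the GEMM epilogue D = alpha*op(A)*op(B) + beta*C that rocBLAS computes.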
@@ -978,6 +1027,19 @@ struct find_gemm_pointwise
         // Already fused gemm
         if(not float_equal(gemm.beta, 0))
             return;
+        gemm.beta = 1;
+        if(not update_gemm(
+               gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
+            return;
+        // const-fold input if not standard shape since rocblas can't handle it
+        if(not c_ins->get_shape().standard())
+        {
+            auto c = op::contiguous{};
+            auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
+            c_ins  = m.add_literal(l.get_shape(), l.data());
+        }
         auto inputs = gemm_ins->inputs();
         inputs.pop_back();
@@ -985,11 +1047,68 @@ struct find_gemm_pointwise
         inputs.push_back(c_ins);
         inputs.push_back(ins->inputs().back());
 
-        gemm.beta = 1;
         m.replace_instruction(ins, gemm, inputs);
     }
 };
 
+struct find_contiguous_tranpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
+                .bind("transpose")));
+    }
+
+    template <class Vector>
+    static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
+    {
+        if(i >= perm.size() or j >= perm.size())
+            return false;
+        auto perm2 = perm;
+        std::iota(perm2.begin(), perm2.end(), 0);
+        std::swap(perm2[i], perm2[j]);
+        return perm2 == perm;
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins       = r.result;
+        auto gemm      = r.instructions["gemm"];
+        auto alloc     = gemm->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm      = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto iperm     = invert_permutation(perm);
+
+        if(perm.size() < 3)
+            return;
+
+        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
+            return;
+
+        auto lens = gemm->get_shape().lens();
+        if(lens.size() > 3 and
+           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
+            return;
+
+        auto gemmv           = gemm->get_operator().to_value();
+        gemmv["trans_batch"] = 1;
+
+        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
+        auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
+
+        auto alloc_transpose =
+            m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
+
+        auto inputs   = gemm->inputs();
+        inputs.back() = alloc_transpose;
+
+        auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
+
+        auto gemm_transpoe = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
+
+        m.replace_instruction(ins, gemm_transpoe);
+    }
+};
+
 struct find_commutative_broadcast
 {
     auto matcher() const
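`is_swapped` only fires when the permutation is the identity with one pair of trailing axes exchanged, i.e. the transpose merely swaps the batch axis with the matrix rows; the gemm can then write directly into a transposed allocation (recorded as `trans_batch`) and the transpose/contiguous pair disappears. A standalone mirror of the check, with example inputs:

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Mirror of find_contiguous_tranpose_gemm::is_swapped: perm must equal the
// identity permutation with exactly the axes i and j exchanged.
bool is_swapped(const std::vector<std::int64_t>& perm, std::size_t i, std::size_t j)
{
    if(i >= perm.size() or j >= perm.size())
        return false;
    std::vector<std::int64_t> identity(perm.size());
    std::iota(identity.begin(), identity.end(), 0);
    std::swap(identity[i], identity[j]);
    return identity == perm;
}
// is_swapped({0, 2, 1, 3}, 1, 2) == true   (batch axis swapped with rows)
// is_swapped({0, 2, 3, 1}, 1, 2) == false  (a 3-cycle, not a single swap)
```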
@@ -1091,6 +1210,7 @@ void fuse_ops::apply(module& m) const
                         find_gemm_add{},
                         find_layernorm_pointwise{},
                         find_gemm_pointwise{},
+                        find_contiguous_tranpose_gemm{},
                         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }

---
@@ -24,6 +24,7 @@
 #include <rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
+#include <migraphx/permutation.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
         MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
 }
 
+shape transpose_batch(const shape& s, unsigned trans_batch)
+{
+    if(trans_batch == 0)
+        return s;
+    if(s.lens().size() < 3)
+        return s;
+    auto batch = s.lens().size() - 3;
+    std::vector<int64_t> perm(s.lens().size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[batch], perm[batch + trans_batch]);
+    return shape::from_permutation(s.type(), s.lens(), perm);
+}
+
 template <class R, class... Ts, class... Us>
 R rocblas_invoke(R (*f)(Ts...), Us... xs)
 {
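A hedged illustration of `transpose_batch` (assuming `shape::from_permutation` builds the shape that views a buffer laid out in the permuted order): for a standard 3-D output and `trans_batch == 1`, the lens are kept but the strides become those of a batch/rows-swapped layout.

```cpp
// Illustrative only, under the assumption above:
//   s                     = {float_type, {2, 3, 4}}  // standard strides {12, 4, 1}
//   transpose_batch(s, 1)                            // perm = {1, 0, 2}
//     -> lens {2, 3, 4}, strides roughly {4, 8, 1}   // batch swapped with rows in memory
```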
@@ -97,6 +111,12 @@ void gemm_impl(context& ctx,
                bool int8_x4_format,
                bool compute_fp32)
 {
+    const bool is_3inputs = (args.size() == 4);
+    if(!is_3inputs)
+    {
+        beta = 0;
+    }
+
     bool transa = is_transposed(args[0].get_shape());
     bool transb = is_transposed(args[1].get_shape());
     auto n_dim  = output_shape.lens().size();

@@ -105,12 +125,8 @@ void gemm_impl(context& ctx,
     rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
     rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
     rocblas_int ldc = args[2].get_shape().strides()[dim_0];
+    rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
 
-    bool is_3inputs = (args.size() == 4);
-    if(!is_3inputs)
-    {
-        beta = 0;
-    }
     rocblas_datatype arg_type = get_type(args[0].get_shape().type());
     auto output_type          = arg_type;
     if(output_type == rocblas_datatype_i8_r)
@@ -186,7 +202,7 @@ void gemm_impl(context& ctx,
                        ldc,
                        is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                        output_type,
-                       ldc,
+                       ldd,
                        compute_type,
                        rocblas_gemm_algo_standard,
                        0,

@@ -197,6 +213,7 @@ void gemm_impl(context& ctx,
         auto a_stride = get_batch_stride(args[0]);
         auto b_stride = get_batch_stride(args[1]);
         auto c_stride = get_batch_stride(args[2]);
+        auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
         rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                        ctx.get_stream().get_rocblas(),
                        transb ? rocblas_operation_transpose : rocblas_operation_none,

@@ -220,8 +237,8 @@ void gemm_impl(context& ctx,
                        c_stride,
                        is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                        output_type,
-                       ldc,
-                       c_stride,
+                       ldd,
+                       d_stride,
                        num_matrices,
                        compute_type,
                        rocblas_gemm_algo_standard,
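Context for the ldd/d_stride changes: rocblas_gemm_ex and rocblas_gemm_strided_batched_ex compute D = alpha*op(A)*op(B) + beta*C with separate C and D buffers. The code previously passed ldc (and c_stride) for D as well; now that the fused C input can carry its own layout, D gets its own leading dimension and batch stride, falling back to C's values when only two inputs are given (in which case beta is forced to 0 and the C pointer aliases D).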
---
@@ -42,15 +42,17 @@ namespace gpu {
 struct context;
 
 void blas_shape(const shape& s);
+shape transpose_batch(const shape& s, unsigned trans_batch);
 
 template <class Op>
 struct rocblas_gemm
 {
     Op op;
     float alpha         = 1;
     float beta          = 0;
     bool int8_x4_format = true;
     bool compute_fp32   = false;
+    unsigned trans_batch = 0;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)

@@ -58,7 +60,9 @@ struct rocblas_gemm
         return pack_join(migraphx::reflect(self.op, f),
                          pack(f(self.alpha, "alpha"),
                               f(self.beta, "beta"),
-                              f(self.int8_x4_format, "int8_x4_format")));
+                              f(self.int8_x4_format, "int8_x4_format"),
+                              f(self.compute_fp32, "compute_fp32"),
+                              f(self.trans_batch, "trans_batch")));
     }
 
     std::string name() const

@@ -74,13 +78,14 @@ struct rocblas_gemm
     {
         std::vector<shape> in_shapes(inputs);
         in_shapes.pop_back();
-        check_shapes{in_shapes, *this}.not_broadcasted();
+        check_shapes{in_shapes, *this}.has(2, 3);
         blas_shape(inputs[0]);
         blas_shape(inputs[1]);
         // if gemm and add are fused
         if(in_shapes.size() > 2)
         {
             auto cmat_shape = in_shapes.back();
+            check_shapes{{cmat_shape}, *this}.not_transposed().not_broadcasted();
             in_shapes.pop_back();
             blas_shape(cmat_shape);
             auto op_out_shape = op.compute_shape(in_shapes);

@@ -97,10 +102,10 @@ struct rocblas_gemm
                                to_string(cmat_shape.type()) +
                                ", it must be: " + to_string(op_out_shape.type()));
             }
-            return op_out_shape;
+            return transpose_batch(op_out_shape, trans_batch);
         }
 
-        return op.compute_shape(in_shapes);
+        return transpose_batch(op.compute_shape(in_shapes), trans_batch);
     }
 
     argument
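A small fix rides along here: `compute_fp32` was apparently not reflected before, so it would have been dropped when the op round-trips through value conversion; it is now packed alongside the new `trans_batch` field, which in turn lets `compute_shape` report the batch-transposed output layout.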
---
@@ -134,8 +134,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options)
         lowering{&ctx, options.offload_copy},
         eliminate_contiguous{"gpu::contiguous"},
         dead_code_elimination{},
-        replace_allocate{gpu_allocation_model{}, options.offload_copy},
-        dead_code_elimination{},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
         pack_int8_args{},

@@ -144,6 +142,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options)
         dead_code_elimination{},
         fuse_ops{&ctx, options.fast_math},
         dead_code_elimination{},
+        replace_allocate{gpu_allocation_model{}, options.offload_copy},
+        dead_code_elimination{},
         compile_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
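Presumably `replace_allocate` moves after `fuse_ops` so that allocations introduced during fusion, such as the transposed output buffer created by `find_contiguous_tranpose_gemm`, are still lowered by the same pass.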
---
@@ -186,9 +186,10 @@ struct nop
     migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
 };
 
-inline migraphx::literal get_2x2()
+inline migraphx::literal get_2x2(int base = 0)
 {
-    return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
+    return migraphx::literal{{migraphx::shape::float_type, {2, 2}},
+                             {base + 1, base + 2, base + 3, base + 4}};
 }
 
 inline migraphx::literal get_2x2_transposed()
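(With the new offset, `get_2x2(1)` yields {2, 3, 4, 5}, giving the dot tests a second, distinct constant matrix.)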
---
@@ -358,7 +358,33 @@ TEST_CASE(simplify_mul_add)
     EXPECT(m1 == m2);
 }
 
+TEST_CASE(simplify_dot_add)
+{
+    migraphx::module m1;
+    {
+        auto x   = m1.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
+        auto one = m1.add_literal(get_2x2());
+        auto two = m1.add_literal(get_2x2(1));
+        auto sum = m1.add_instruction(migraphx::make_op("add"), one, x);
+        auto dot = m1.add_instruction(migraphx::make_op("dot"), sum, two);
+        m1.add_instruction(pass_op{}, dot);
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto x    = m2.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
+        auto one  = m2.add_literal(get_2x2());
+        auto two  = m2.add_literal(get_2x2(1));
+        auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x, two);
+        auto dot2 = m2.add_instruction(migraphx::make_op("dot"), one, two);
+        auto sum  = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
+        m2.add_instruction(pass_op{}, sum);
+    }
+    EXPECT(m1 == m2);
+}
+
-TEST_CASE(simplify_inner_broadcast)
+TEST_CASE(simplify_inner_broadcast1)
 {
     auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
     migraphx::module m1;

@@ -383,6 +409,31 @@ TEST_CASE(simplify_inner_broadcast)
     EXPECT(m1 == m2);
 }
 
+TEST_CASE(simplify_inner_broadcast2)
+{
+    auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}};
+    migraphx::module m1;
+    {
+        auto x   = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
+        auto y   = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}});
+        auto xb  = m1.add_instruction(b, x);
+        auto yb  = m1.add_instruction(b, y);
+        auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
+        m1.add_instruction(pass_op{}, sum);
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto x    = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
+        auto y    = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}});
+        auto sum  = m2.add_instruction(migraphx::make_op("add"), x, y);
+        auto sumb = m2.add_instruction(b, sum);
+        m2.add_instruction(pass_op{}, sumb);
+    }
+    EXPECT(m1 == m2);
+}
+
 TEST_CASE(simplify_add_conv1)
 {
     migraphx::module m;

---
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_add_broadcast1 : verify_program<gemm_add_broadcast1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 1, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = mm->add_parameter("3", m3_shape);
auto l3_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), l3);
auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2);
mm->add_instruction(migraphx::make_op("add"), dot, l3_b);
return p;
}
};
---

/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_add_broadcast2 : verify_program<gemm_add_broadcast2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 1}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = mm->add_parameter("3", m3_shape);
auto l3_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), l3);
auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2);
mm->add_instruction(migraphx::make_op("add"), dot, l3_b);
return p;
}
};
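Both new verify tests drive a dot whose add operand is a multibroadcast C matrix: gemm_add_broadcast1 broadcasts a per-column vector ({1, 1, 4}) and gemm_add_broadcast2 a per-row vector ({1, 2, 1}), exercising the new non-standard-C handling in the GPU gemm path against the reference implementation.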