Unverified Commit 457703a8 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Rewrite multiplies with dot operator (#1685)

When either the input or the output of a dot is multiplied across the K dimensions, the multiplier can instead be applied to the constant operand, which can then be folded by propagate_const.
parent 7f105952
......@@ -204,6 +204,131 @@ struct find_mul_slice_conv
}
};
// Matches mul(dot(a, b), broadcast) where at least one dot input is constant,
// and rewrites it as a dot whose evaluable input has been pre-multiplied by
// the (re-broadcast) multiplier.
struct find_mul_dot
{
    auto matcher() const
    {
        auto scale_broadcast = match::name("broadcast", "multibroadcast").bind("c");
        auto dot_with_const =
            match::name("dot")(match::any_of[match::inputs()](match::is_constant())).bind("dot");
        return match::name("mul")(match::either_arg(0, 1)(dot_with_const, scale_broadcast));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto mul_ins   = r.result;
        auto scale_ins = r.instructions["c"];
        auto dot_ins   = r.instructions["dot"];

        // Bail out when the broadcast varies along more than one axis
        const auto& strides = scale_ins->get_shape().strides();
        const auto nonzero_strides =
            std::count_if(strides.begin(), strides.end(), [](auto s) { return s != 0; });
        if(nonzero_strides > 1)
            return;

        // Re-broadcast the multiplier to the operand's dimensions and multiply
        // it into the operand. Returns m.end() when the operand cannot be
        // evaluated, in which case nothing is inserted.
        auto scale_operand = [&](instruction_ref operand) {
            if(not operand->can_eval())
                return m.end();
            auto op_value        = scale_ins->get_operator().to_value();
            op_value["out_lens"] = operand->get_shape().lens();
            auto rebroadcast     = m.insert_instruction(
                mul_ins, make_op(scale_ins->name(), op_value), scale_ins->inputs());
            return m.insert_instruction(mul_ins, make_op("mul"), operand, rebroadcast);
        };

        auto lhs = dot_ins->inputs()[0];
        auto rhs = dot_ins->inputs()[1];

        if(strides.back() == 1)
        {
            // Unit stride on the last axis: scale the second dot input
            rhs = scale_operand(rhs);
        }
        else if(strides[strides.size() - 2] == 1)
        {
            // Unit stride on the second-to-last axis: scale the first dot input
            lhs = scale_operand(lhs);
        }
        else if(not scale_ins->get_shape().scalar())
        {
            return;
        }
        else if(lhs->can_eval())
        {
            // Scalar multiplier: scale the first input if it is evaluable
            lhs = scale_operand(lhs);
        }
        else
        {
            rhs = scale_operand(rhs);
        }

        // The chosen operand was not evaluable, so leave the graph untouched
        if(lhs == m.end() or rhs == m.end())
            return;

        m.replace_instruction(mul_ins, make_op("dot"), lhs, rhs);
    }
};
// Matches dot(mul(z, d), c) or dot(c, mul(z, d)) where d is a constant
// broadcast, c is constant and z is not, and rewrites it so that d is
// multiplied into c instead of z.
struct find_dot_mul
{
    auto matcher() const
    {
        auto scale =
            match::name("broadcast", "multibroadcast")(match::is_constant()).bind("d");
        auto variable = match::none_of(match::is_constant()).bind("z");
        auto constant = match::is_constant().bind("c");
        auto scaled =
            match::name("mul")(match::used_once(), match::either_arg(0, 1)(scale, variable));
        return match::name("dot")(match::either_arg(0, 1)(scaled, constant));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto dot_ins   = r.result;
        auto lhs       = dot_ins->inputs()[0];
        auto rhs       = dot_ins->inputs()[1];
        auto scale_ins = r.instructions["d"];
        auto const_ins = r.instructions["c"];
        auto var_ins   = r.instructions["z"];

        // Bail out when the broadcast varies along more than one axis
        const auto& strides = scale_ins->get_shape().strides();
        const auto nonzero_strides =
            std::count_if(strides.begin(), strides.end(), [](auto s) { return s != 0; });
        if(nonzero_strides > 1)
            return;

        if(not scale_ins->get_shape().scalar())
        {
            // Unit stride on the last axis requires an evaluable second input
            if(strides.back() == 1 and not rhs->can_eval())
                return;
            // Unit stride on the second-to-last axis requires an evaluable first input
            if(strides[strides.size() - 2] == 1 and not lhs->can_eval())
                return;
        }

        // Permutation that swaps the last two axes
        std::vector<int64_t> permutation(const_ins->get_shape().lens().size());
        std::iota(permutation.begin(), permutation.end(), 0);
        std::swap(permutation.back(), permutation[permutation.size() - 2]);

        // Re-broadcast the multiplier to the constant's dimensions with the last
        // two axes swapped, then transpose it to line up with the constant.
        auto op_value        = scale_ins->get_operator().to_value();
        op_value["out_lens"] = reorder_dims(const_ins->get_shape().lens(), permutation);
        auto rebroadcast     = m.insert_instruction(
            dot_ins, make_op(scale_ins->name(), op_value), scale_ins->inputs());
        auto transposed = m.insert_instruction(
            dot_ins, make_op("transpose", {{"permutation", permutation}}), rebroadcast);
        auto scaled_const = m.insert_instruction(dot_ins, make_op("mul"), const_ins, transposed);

        // Keep the constant on whichever side of the dot it originally occupied
        if(const_ins == rhs)
            m.replace_instruction(dot_ins, make_op("dot"), var_ins, scaled_const);
        else
            m.replace_instruction(dot_ins, make_op("dot"), scaled_const, var_ins);
    }
};
// ******************************
// a * (x + b) => a * x + a * b
// ******************************
......@@ -1367,6 +1492,8 @@ void simplify_algebra::apply(module& m) const
find_conv_dot_horiz_fusion{},
find_mul_conv{},
find_mul_slice_conv{},
find_mul_dot{},
find_dot_mul{},
find_mul_add{},
find_unit_ops{},
find_neg_unit_ops{},
......
......@@ -3138,4 +3138,257 @@ TEST_CASE(dot_fusion_reshape)
EXPECT(m1.sort() == m2.sort());
}
// A parameter scaled along its last axis and then used as the first dot input:
// the expected module applies the multiplier to the literal second input instead.
TEST_CASE(mul_dot_a)
{
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
    migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

    migraphx::module m1;
    {
        auto input = m1.add_parameter("input", a_shape);
        auto scale =
            m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 32}}));
        auto scale_bcast = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", a_shape.lens()}}), scale);
        auto scaled_input = m1.add_instruction(migraphx::make_op("mul"), input, scale_bcast);
        auto weights      = m1.add_literal(migraphx::generate_literal(b_shape));
        auto result = m1.add_instruction(migraphx::make_op("dot"), scaled_input, weights);
        m1.add_return({result});
    }
    run_pass(m1);

    migraphx::module m2;
    {
        auto input = m2.add_parameter("input", a_shape);
        auto scale =
            m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 32}}));
        auto scale_bcast = m2.add_instruction(
            migraphx::make_op("multibroadcast",
                              {{"out_lens", migraphx::reorder_dims(b_shape.lens(), {0, 2, 1})}}),
            scale);
        auto scale_transposed = m2.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), scale_bcast);
        auto weights = m2.add_literal(migraphx::generate_literal(b_shape));
        auto scaled_weights =
            m2.add_instruction(migraphx::make_op("mul"), weights, scale_transposed);
        auto result = m2.add_instruction(migraphx::make_op("dot"), input, scaled_weights);
        m2.add_return({result});
    }
    EXPECT(m1.sort() == m2.sort());
}
// A parameter scaled along its second-to-last axis and then used as the second
// dot input: the expected module applies the multiplier to the literal first
// input instead.
TEST_CASE(mul_dot_b)
{
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
    migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

    migraphx::module m1;
    {
        auto input = m1.add_parameter("input", b_shape);
        auto scale =
            m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 32, 1}}));
        auto scale_bcast = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", b_shape.lens()}}), scale);
        auto scaled_input = m1.add_instruction(migraphx::make_op("mul"), input, scale_bcast);
        auto weights      = m1.add_literal(migraphx::generate_literal(a_shape));
        auto result = m1.add_instruction(migraphx::make_op("dot"), weights, scaled_input);
        m1.add_return({result});
    }
    run_pass(m1);

    migraphx::module m2;
    {
        auto input = m2.add_parameter("input", b_shape);
        auto scale =
            m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 32, 1}}));
        auto scale_bcast = m2.add_instruction(
            migraphx::make_op("multibroadcast",
                              {{"out_lens", migraphx::reorder_dims(a_shape.lens(), {0, 2, 1})}}),
            scale);
        auto scale_transposed = m2.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), scale_bcast);
        auto weights = m2.add_literal(migraphx::generate_literal(a_shape));
        auto scaled_weights =
            m2.add_instruction(migraphx::make_op("mul"), weights, scale_transposed);
        auto result = m2.add_instruction(migraphx::make_op("dot"), scaled_weights, input);
        m2.add_return({result});
    }
    EXPECT(m1.sort() == m2.sort());
}
// The multiplier on the first dot input varies along the second-to-last axis
// rather than the last one; the module is expected to be unchanged by the pass.
TEST_CASE(mul_dot_a_not_k_broadcast)
{
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
    migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

    migraphx::module m1;
    {
        auto input = m1.add_parameter("input", a_shape);
        auto scale = m1.add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 256, 1}}));
        auto scale_bcast = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", a_shape.lens()}}), scale);
        auto scaled_input = m1.add_instruction(migraphx::make_op("mul"), input, scale_bcast);
        auto weights      = m1.add_literal(migraphx::generate_literal(b_shape));
        auto result = m1.add_instruction(migraphx::make_op("dot"), scaled_input, weights);
        m1.add_return({result});
    }
    // Snapshot the module before running the pass
    migraphx::module m2 = m1;
    run_pass(m1);
    EXPECT(m1.sort() == m2.sort());
}
// The multiplier on the second dot input varies along the last axis rather
// than the second-to-last one; the module is expected to be unchanged by the pass.
TEST_CASE(mul_dot_b_not_k_broadcast)
{
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
    migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

    migraphx::module m1;
    {
        auto input = m1.add_parameter("input", b_shape);
        auto scale = m1.add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 128}}));
        auto scale_bcast = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", b_shape.lens()}}), scale);
        auto scaled_input = m1.add_instruction(migraphx::make_op("mul"), input, scale_bcast);
        auto weights      = m1.add_literal(migraphx::generate_literal(a_shape));
        auto result = m1.add_instruction(migraphx::make_op("dot"), weights, scaled_input);
        m1.add_return({result});
    }
    // Snapshot the module before running the pass
    migraphx::module m2 = m1;
    run_pass(m1);
    EXPECT(m1.sort() == m2.sort());
}
// The dot output is scaled along its last axis and the second dot input is a
// literal: the expected module multiplies the literal before the dot.
TEST_CASE(dot_mul_a)
{
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
    migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

    migraphx::module m1;
    {
        auto input   = m1.add_parameter("input", a_shape);
        auto weights = m1.add_literal(migraphx::generate_literal(b_shape));
        auto product = m1.add_instruction(migraphx::make_op("dot"), input, weights);
        auto scale   = m1.add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 128}}));
        auto scale_bcast = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", product->get_shape().lens()}}),
            scale);
        auto result = m1.add_instruction(migraphx::make_op("mul"), product, scale_bcast);
        m1.add_return({result});
    }
    run_pass(m1);

    migraphx::module m2;
    {
        auto input   = m2.add_parameter("input", a_shape);
        auto weights = m2.add_literal(migraphx::generate_literal(b_shape));
        auto scale   = m2.add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 128}}));
        auto scale_bcast = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", b_shape.lens()}}), scale);
        auto scaled_weights = m2.add_instruction(migraphx::make_op("mul"), weights, scale_bcast);
        auto result = m2.add_instruction(migraphx::make_op("dot"), input, scaled_weights);
        m2.add_return({result});
    }
    EXPECT(m1.sort() == m2.sort());
}
// The dot output is scaled along its second-to-last axis while the first dot
// input is a parameter; the module is expected to be unchanged by the pass.
TEST_CASE(dot_mul_a_non_const)
{
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
    migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

    migraphx::module m1;
    {
        auto input   = m1.add_parameter("input", a_shape);
        auto weights = m1.add_literal(migraphx::generate_literal(b_shape));
        auto product = m1.add_instruction(migraphx::make_op("dot"), input, weights);
        auto scale   = m1.add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 256, 1}}));
        auto scale_bcast = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", product->get_shape().lens()}}),
            scale);
        auto result = m1.add_instruction(migraphx::make_op("mul"), product, scale_bcast);
        m1.add_return({result});
    }
    // Snapshot the module before running the pass
    migraphx::module m2 = m1;
    run_pass(m1);
    EXPECT(m1.sort() == m2.sort());
}
// The dot output is scaled along its second-to-last axis and the first dot
// input is a literal: the expected module multiplies the literal before the dot.
TEST_CASE(dot_mul_b)
{
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
    migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

    migraphx::module m1;
    {
        auto weights = m1.add_literal(migraphx::generate_literal(a_shape));
        auto input   = m1.add_parameter("input", b_shape);
        auto product = m1.add_instruction(migraphx::make_op("dot"), weights, input);
        auto scale   = m1.add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 256, 1}}));
        auto scale_bcast = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", product->get_shape().lens()}}),
            scale);
        auto result = m1.add_instruction(migraphx::make_op("mul"), product, scale_bcast);
        m1.add_return({result});
    }
    run_pass(m1);

    migraphx::module m2;
    {
        auto weights = m2.add_literal(migraphx::generate_literal(a_shape));
        auto input   = m2.add_parameter("input", b_shape);
        auto scale   = m2.add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 256, 1}}));
        auto scale_bcast = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", a_shape.lens()}}), scale);
        auto scaled_weights = m2.add_instruction(migraphx::make_op("mul"), weights, scale_bcast);
        auto result = m2.add_instruction(migraphx::make_op("dot"), scaled_weights, input);
        m2.add_return({result});
    }
    EXPECT(m1.sort() == m2.sort());
}
// The dot output is scaled along its last axis while the second dot input is a
// parameter; the module is expected to be unchanged by the pass.
TEST_CASE(dot_mul_b_non_const)
{
    migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
    migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

    migraphx::module m1;
    {
        auto weights = m1.add_literal(migraphx::generate_literal(a_shape));
        auto input   = m1.add_parameter("input", b_shape);
        auto product = m1.add_instruction(migraphx::make_op("dot"), weights, input);
        auto scale   = m1.add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 128}}));
        auto scale_bcast = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", product->get_shape().lens()}}),
            scale);
        auto result = m1.add_instruction(migraphx::make_op("mul"), product, scale_bcast);
        m1.add_return({result});
    }
    // Snapshot the module before running the pass
    migraphx::module m2 = m1;
    run_pass(m1);
    EXPECT(m1.sort() == m2.sort());
}
// Test executable entry point: forwards the command line to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
// Verify program: dot(parameter, literal) whose output is multiplied by a
// literal broadcast from {1, 1, 128}.
struct test_dot_mul_a : verify_program<test_dot_mul_a>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
        migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

        auto input   = main_mod->add_parameter("input", a_shape);
        auto weights = main_mod->add_literal(migraphx::generate_literal(b_shape));
        auto product = main_mod->add_instruction(migraphx::make_op("dot"), input, weights);
        auto scale   = main_mod->add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 128}}));
        auto scale_bcast = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", product->get_shape().lens()}}),
            scale);
        auto result = main_mod->add_instruction(migraphx::make_op("mul"), product, scale_bcast);
        main_mod->add_return({result});
        return prog;
    }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
// Verify program: dot(literal, parameter) whose output is multiplied by a
// literal broadcast from {1, 256, 1}.
struct test_dot_mul_b : verify_program<test_dot_mul_b>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
        migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

        auto weights = main_mod->add_literal(migraphx::generate_literal(a_shape));
        auto input   = main_mod->add_parameter("input", b_shape);
        auto product = main_mod->add_instruction(migraphx::make_op("dot"), weights, input);
        auto scale   = main_mod->add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 256, 1}}));
        auto scale_bcast = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", product->get_shape().lens()}}),
            scale);
        auto result = main_mod->add_instruction(migraphx::make_op("mul"), product, scale_bcast);
        main_mod->add_return({result});
        return prog;
    }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verify program: a parameter multiplied by a literal broadcast from
// {1, 1, 32}, then used as the first input of a dot with a literal.
struct test_mul_dot_a : verify_program<test_mul_dot_a>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
        migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

        auto input = main_mod->add_parameter("input", a_shape);
        auto scale = main_mod->add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 32}}));
        auto scale_bcast = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", a_shape.lens()}}), scale);
        auto scaled_input =
            main_mod->add_instruction(migraphx::make_op("mul"), input, scale_bcast);
        auto weights = main_mod->add_literal(migraphx::generate_literal(b_shape));
        auto result =
            main_mod->add_instruction(migraphx::make_op("dot"), scaled_input, weights);
        main_mod->add_return({result});
        return prog;
    }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verify program: a parameter multiplied by a literal broadcast from
// {1, 32, 1}, then used as the second input of a dot with a literal.
struct test_mul_dot_b : verify_program<test_mul_dot_b>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        migraphx::shape a_shape{migraphx::shape::float_type, {2, 256, 32}};
        migraphx::shape b_shape{migraphx::shape::float_type, {2, 32, 128}};

        auto input = main_mod->add_parameter("input", b_shape);
        auto scale = main_mod->add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {1, 32, 1}}));
        auto scale_bcast = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", b_shape.lens()}}), scale);
        auto scaled_input =
            main_mod->add_instruction(migraphx::make_op("mul"), input, scale_bcast);
        auto weights = main_mod->add_literal(migraphx::generate_literal(a_shape));
        auto result =
            main_mod->add_instruction(migraphx::make_op("dot"), weights, scaled_input);
        main_mod->add_return({result});
        return prog;
    }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment