"library/vscode:/vscode.git/clone" did not exist on "5af78ac26ba2e4f0464b63fb159e77b73307708d"
Unverified commit 21193e87, authored by Umang Yadav, committed by GitHub

Remove alpha and beta from `dot` and `quant_dot` (#961)

Previously the dot operator was defined as C = alpha * (A . B) + beta * C, where * is scalar multiplication and . is the dot product or matrix multiplication, depending on the dimensions of the inputs.

The aim is to define the dot operator simply as C = A . B, with no alpha or beta attributes.

To achieve the same effect as alpha and beta, the new apply_alpha_beta helper (1) multiplies one of the inputs to the dot operator by alpha, and (2) if a C operand is present, multiplies C by beta and adds the result to the output of step (1).
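Concretely, the rewrite replaces the fused operator with a chain of elementwise operators around a plain dot. As a worked sketch of the transformation implemented by the apply_alpha_beta helper added in this commit:

C = alpha * (A . B) + beta * C
  = add( dot( mul(alpha, A), B ), mul(beta, C) )

The mul by alpha is skipped when alpha == 1, and the beta branch is skipped when there is no C operand or beta == 0.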
parent 87978f03
CMakeLists.txt:

@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag)
 add_library(migraphx
   adjust_allocation.cpp
   analyze_streams.cpp
+  apply_alpha_beta.cpp
   argument.cpp
   auto_contiguous.cpp
   common.cpp
@@ -14,7 +15,6 @@ add_library(migraphx
   convert_to_json.cpp
   cpp_generator.cpp
   dead_code_elimination.cpp
-  decompose.cpp
   dom_info.cpp
   dynamic_loader.cpp
   eliminate_allocation.cpp
@@ -52,7 +52,6 @@ add_library(migraphx
   reduce_dims.cpp
   register_op.cpp
   register_target.cpp
-  remap.cpp
   simplify_qdq.cpp
   rewrite_batchnorm.cpp
   rewrite_pooling.cpp
apply_alpha_beta.cpp (new file):

#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/apply_alpha_beta.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

instruction_ref insert_apply_alpha_beta(module& m,
                                        instruction_ref pos,
                                        const std::vector<instruction_ref>& args,
                                        const operation& op,
                                        const literal& alpha,
                                        const literal& beta)
{
    auto a          = args[0];
    auto b          = args[1];
    auto input_type = a->get_shape().type();
    if(not float_equal(alpha.at<float>(0), 1.0f))
    {
        auto alpha_literal = m.add_literal(alpha);
        a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
        if(a->get_shape().type() != input_type)
        {
            a = m.insert_instruction(pos, make_op("convert", {{"target_type", input_type}}), a);
        }
    }
    auto op_res = m.insert_instruction(pos, op, a, b);
    if(args.size() == 3)
    {
        if(not float_equal(beta.at<float>(0), 0.0f) and args[2]->get_shape().elements() > 0)
        {
            auto out_lens = op_res->get_shape().lens();
            auto c        = args[2];
            auto c_lens   = c->get_shape().lens();
            input_type    = c->get_shape().type();
            if(out_lens != c_lens)
            {
                c = m.insert_instruction(
                    pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
            }
            auto beta_literal = m.add_literal(beta);
            auto beta_c = insert_common_op(m, pos, migraphx::make_op("mul"), {c, beta_literal});
            if(beta_c->get_shape().type() != input_type)
            {
                beta_c = m.insert_instruction(
                    pos, migraphx::make_op("convert", {{"target_type", input_type}}), beta_c);
            }
            return m.insert_instruction(pos, migraphx::make_op("add"), op_res, beta_c);
        }
    }
    return op_res;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
File mode changed from 100755 to 100644
decompose.cpp (deleted):

#include <migraphx/decompose.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {

struct alpha_beta
{
    float alpha = 0.0;
    float beta  = 0.0;
};

alpha_beta get_alpha_beta(const operation& op)
{
    auto v = op.to_value();
    return {v.at("alpha").to<float>(), v.at("beta").to<float>()};
}

struct find_dot_add
{
    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(3)); }

    void apply(module& p, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto dot   = get_alpha_beta(ins->get_operator());
        auto a_ins = ins->inputs()[0];
        auto b_ins = ins->inputs()[1];
        if(not float_equal(dot.alpha, 1))
        {
            auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
            auto alpha_broadcast = p.insert_instruction(
                ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
            a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
        }
        auto dot_ins =
            p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
        auto c_ins = ins->inputs()[2];
        if(not float_equal(dot.beta, 1))
        {
            auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
            auto beta_broadcast = p.insert_instruction(
                ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), beta);
            c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
        }
        p.replace_instruction(ins, make_op("add"), dot_ins, c_ins);
    }
};

struct find_dot_alpha
{
    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(2)); }

    void apply(module& p, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto dot   = get_alpha_beta(ins->get_operator());
        auto a_ins = ins->inputs()[0];
        auto b_ins = ins->inputs()[1];
        if(not float_equal(dot.alpha, 1))
        {
            auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
            auto alpha_broadcast = p.insert_instruction(
                ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
            a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
        }
        p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
    }
};

} // namespace

void decompose::apply(module& p) const { match::find_matches(p, find_dot_add{}, find_dot_alpha{}); }

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
 #include <migraphx/operators.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/generate.hpp>
+#include <migraphx/apply_alpha_beta.hpp>
 #include "models.hpp"

 namespace migraphx {
@@ -144,10 +145,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
     migraphx::op::multibroadcast multibroadcast42;
     multibroadcast42.output_lens = {batch, 4096};
     auto mx42 = mm->add_instruction(multibroadcast42, mx4);
-    migraphx::op::dot dot43;
-    dot43.alpha = 1;
-    dot43.beta  = 1;
-    auto mx43 = mm->add_instruction(dot43, mx40, mx41, mx42);
+    float dot43_alpha = 1;
+    float dot43_beta  = 1;
+    auto mx43 = migraphx::add_apply_alpha_beta(
+        *mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta);
     migraphx::op::relu relu44;
     auto mx44 = mm->add_instruction(relu44, mx43);
     migraphx::op::identity identity45;
@@ -158,10 +159,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
     migraphx::op::multibroadcast multibroadcast47;
     multibroadcast47.output_lens = {batch, 4096};
     auto mx47 = mm->add_instruction(multibroadcast47, mx2);
-    migraphx::op::dot dot48;
-    dot48.alpha = 1;
-    dot48.beta  = 1;
-    auto mx48 = mm->add_instruction(dot48, mx45, mx46, mx47);
+    float dot48_alpha = 1;
+    float dot48_beta  = 1;
+    auto mx48 = migraphx::add_apply_alpha_beta(
+        *mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta);
     migraphx::op::relu relu49;
     auto mx49 = mm->add_instruction(relu49, mx48);
     migraphx::op::transpose transpose50;
@@ -170,10 +171,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
     migraphx::op::multibroadcast multibroadcast51;
     multibroadcast51.output_lens = {batch, 1000};
     auto mx51 = mm->add_instruction(multibroadcast51, mx0);
-    migraphx::op::dot dot52;
-    dot52.alpha = 1;
-    dot52.beta  = 1;
-    mm->add_instruction(dot52, mx49, mx50, mx51);
+    float dot52_alpha = 1;
+    float dot52_beta  = 1;
+    migraphx::add_apply_alpha_beta(
+        *mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
     return p;
 }
 #include <migraphx/operators.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/generate.hpp>
+#include <migraphx/apply_alpha_beta.hpp>
 #include "models.hpp"

 namespace migraphx {
@@ -2225,10 +2226,10 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-size)
     migraphx::op::multibroadcast multibroadcast798;
     multibroadcast798.output_lens = {batch, 1000};
     auto mx798 = mm->add_instruction(multibroadcast798, mx0);
-    migraphx::op::dot dot799;
-    dot799.alpha = 1;
-    dot799.beta  = 1;
-    mm->add_instruction(dot799, mx796, mx797, mx798);
+    float dot799_alpha = 1;
+    float dot799_beta  = 1;
+    migraphx::add_apply_alpha_beta(
+        *mm, {mx796, mx797, mx798}, migraphx::make_op("dot"), dot799_alpha, dot799_beta);
     return p;
 }
 #include <migraphx/operators.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/generate.hpp>
+#include <migraphx/apply_alpha_beta.hpp>
 #include "models.hpp"

 namespace migraphx {
@@ -1228,10 +1229,10 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
     migraphx::op::multibroadcast multibroadcast442;
     multibroadcast442.output_lens = {batch, 1000};
     auto mx442 = mm->add_instruction(multibroadcast442, mx0);
-    migraphx::op::dot dot443;
-    dot443.alpha = 1;
-    dot443.beta  = 1;
-    mm->add_instruction(dot443, mx440, mx441, mx442);
+    float dot443_alpha = 1;
+    float dot443_beta  = 1;
+    migraphx::add_apply_alpha_beta(
+        *mm, {mx440, mx441, mx442}, migraphx::make_op("dot"), dot443_alpha, dot443_beta);
     return p;
 }
apply_alpha_beta.hpp (new file):

#ifndef MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP

#include <migraphx/make_op.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/module.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

instruction_ref insert_apply_alpha_beta(module& m,
                                        instruction_ref pos,
                                        const std::vector<instruction_ref>& args,
                                        const operation& op,
                                        const literal& alpha,
                                        const literal& beta);

template <typename T = float>
instruction_ref insert_apply_alpha_beta(module& m,
                                        instruction_ref pos,
                                        const std::vector<instruction_ref>& args,
                                        const operation& op,
                                        T alpha = 1,
                                        T beta  = 0)
{
    return insert_apply_alpha_beta(m, pos, args, op, literal{T{alpha}}, literal{T{beta}});
}

template <typename T = float>
instruction_ref add_apply_alpha_beta(module& m,
                                     const std::vector<instruction_ref>& args,
                                     const operation& op,
                                     T alpha = 1,
                                     T beta  = 0)
{
    return insert_apply_alpha_beta(m, m.end(), args, op, alpha, beta);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
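For illustration, a minimal usage sketch of the helpers declared above. The program, shapes, and parameter names here are hypothetical, not part of this commit:

#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    // Illustrative shapes: A is 2x4, B is 4x3, C is 2x3
    auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {4, 3}});
    auto c = mm->add_parameter("c", migraphx::shape{migraphx::shape::float_type, {2, 3}});
    // Appends mul/dot/mul/add instructions computing 0.5 * (A . B) + 1.0 * C;
    // with the defaults (alpha = 1, beta = 0) it reduces to a plain dot
    migraphx::add_apply_alpha_beta(*mm, {a, b, c}, migraphx::make_op("dot"), 0.5f, 1.0f);
}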
decompose.hpp (deleted):

#ifndef MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP

#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

/**
 * Decompose operators.
 */
struct decompose
{
    std::string name() const { return "decompose"; }
    void apply(module& p) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -18,19 +18,10 @@ namespace op {

 struct dot
 {
-    float alpha = 1.0;
-    float beta  = 1.0;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
-    }
     std::string name() const { return "dot"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.same_type();
+        check_shapes{inputs, *this}.same_type().has(2);
         const shape& a = inputs.at(0);
         const shape& b = inputs.at(1);
         auto t         = a.type();
@@ -58,25 +49,14 @@ struct dot
         auto out_lens   = a.lens();
         out_lens[dim_1] = b.lens()[dim_1];
-        if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
-        {
-            MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
-                           to_string_range(inputs.at(2).lens()) +
-                           "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
-        }
         return {t, out_lens};
     }

     argument compute(shape output_shape, std::vector<argument> args) const
     {
-        argument result;
-        if(args.size() == 3)
-            result = args[2];
-        else
-            result = argument{output_shape};
+        argument result = argument{output_shape};
         visit_all(result, args[0], args[1])(
-            [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, alpha, beta); });
+            [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
         return result;
     }
 };
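After this change, dot is a pure two-input matrix multiply, and compute_shape enforces exactly two same-typed operands via .has(2). A hedged sketch of the new contract (hypothetical names and shapes):

#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {2, 4}});
    auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {4, 3}});
    // OK: C = A . B, output shape {2, 3}
    auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
    // No longer valid: make_op("dot", {{"alpha", ...}, {"beta", ...}}) or passing a
    // third C operand; scaling and accumulation must now be written as explicit
    // mul/add instructions, e.g. via add_apply_alpha_beta shown above.
    (void)dot;
}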
@@ -18,21 +18,12 @@ namespace op {

 struct quant_dot
 {
-    int32_t alpha = 1;
-    int32_t beta  = 1;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
-    }
     value attributes() const { return {{"general_data_type", "dot"}}; }

     std::string name() const { return "quant_dot"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type();
+        check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type().has(2);
         const shape& a = inputs.at(0);
         const shape& b = inputs.at(1);
         auto t         = a.type();
@@ -64,18 +55,6 @@ struct quant_dot
         auto out_lens   = a.lens();
         out_lens[dim_1] = b.lens()[dim_1];
-        if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
-        {
-            MIGRAPHX_THROW("QUANT_DOT: dimension mismatch, operand C: {" +
-                           to_string_range(inputs.at(2).lens()) +
-                           "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
-        }
-        if(inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
-        {
-            MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
-        }
         return {shape::int32_type, out_lens};
     }
 };
remap.hpp (deleted):

#ifndef MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
#define MIGRAPHX_GUARD_RTGLIB_REMAP_HPP

#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

/**
 * Remap operators.
 */
struct remap
{
    std::string name() const { return "remap"; }
    void apply(module& p) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -61,7 +61,7 @@ struct parse_gemm : op_parser<parse_gemm>
                    ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
                    : args[1];

-        auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2);
+        auto ret = info.add_instruction(make_op("dot"), l1, l2);
         if(args.size() == 3)
         {
@@ -66,10 +66,8 @@ struct parse_matmul : op_parser<parse_matmul>
                 make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
             }
         }
-        auto dot_res =
-            info.add_instruction(make_op(opd.op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
-        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
+        instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
+        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
         if(is_a_prepended)
         {
             dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
remap.cpp (deleted):

#include <migraphx/remap.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/add.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {

struct find_dot_add
{
    auto matcher() const
    {
        return match::name("add")(match::any_of(
            match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")),
            match::args(match::used_once().bind("a"),
                        match::name("dot")(match::nargs(2)).bind("dot"))));
    }

    void apply(module& p, match::matcher_result r) const
    {
        auto ins     = r.result;
        auto dot_ins = r.instructions["dot"];
        auto a_ins   = r.instructions["a"];

        auto dot = any_cast<op::dot>(dot_ins->get_operator());
        dot.beta = 1;
        p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins);
    }
};

} // namespace

void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -84,13 +84,7 @@ struct match_find_quantizable_ops
         }
         else if(qop->name() == "dot")
         {
-            auto dot_op = any_cast<op::dot>(qop->get_operator());
-            if(!(float_equal(dot_op.alpha, 1.0f) and float_equal(dot_op.beta, 0.0f)))
-                return;
-            if(qop_args.size() == 3)
-                qop_args.pop_back();
-            dq = m.insert_instruction(
-                qop, migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qop_args);
+            dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
         }
         auto ins_type = qop->get_shape().type();
         dq_scale      = m.add_literal(literal({ins_type}, {scale}));
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <migraphx/check_context.hpp> #include <migraphx/check_context.hpp>
#include <migraphx/adjust_allocation.hpp> #include <migraphx/adjust_allocation.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/decompose.hpp>
#include <migraphx/eliminate_allocation.hpp> #include <migraphx/eliminate_allocation.hpp>
#include <migraphx/eliminate_common_subexpression.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
...@@ -14,7 +13,6 @@ ...@@ -14,7 +13,6 @@
#include <migraphx/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp> #include <migraphx/rewrite_quantization.hpp>
...@@ -52,8 +50,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -52,8 +50,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type}, eliminate_data_type{unsupported_types, shape::type_t::float_type},
dead_code_elimination{}, dead_code_elimination{},
decompose{},
dead_code_elimination{},
simplify_reshapes{}, simplify_reshapes{},
eliminate_identity{}, eliminate_identity{},
eliminate_pad{}, eliminate_pad{},
......
@@ -717,7 +717,7 @@ struct find_gemm_add
         auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());

         // Already fused gemm
-        if(not float_equal(gemm.op.beta, 0))
+        if(not float_equal(gemm.beta, 0))
             return;

         if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
@@ -738,7 +738,7 @@ struct find_gemm_add
         inputs.push_back(copy_ins);
         inputs.push_back(copy_ins);

-        gemm.op.beta = 1;
+        gemm.beta = 1;
         p.replace_instruction(ins, gemm, inputs);
     }
 };
 #ifndef MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP
 #define MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP

+#include <migraphx/errors.hpp>
+#include <migraphx/operation.hpp>
+#include <migraphx/value.hpp>
 #include <migraphx/shape.hpp>
 #include <migraphx/reflect.hpp>
 #include <migraphx/gpu/context.hpp>
@@ -19,13 +22,17 @@ template <class Op>
 struct rocblas_gemm
 {
     Op op;
+    float alpha         = 1;
+    float beta          = 0;
     bool int8_x4_format = true;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
         return pack_join(migraphx::reflect(self.op, f),
-                         pack(f(self.int8_x4_format, "int8_x4_format")));
+                         pack(f(self.alpha, "alpha"),
+                              f(self.beta, "beta"),
+                              f(self.int8_x4_format, "int8_x4_format")));
     }

     std::string name() const
@@ -44,6 +51,26 @@ struct rocblas_gemm
         check_shapes{in_shapes, *this}.not_broadcasted();
         batch_not_transposed(inputs[0].strides());
         batch_not_transposed(inputs[1].strides());
+        // if gemm and add are fused
+        if(not float_equal(beta, 0))
+        {
+            auto cmat_shape = in_shapes.back();
+            in_shapes.pop_back();
+            auto op_out_shape = op.compute_shape(in_shapes);
+            if(cmat_shape.lens() != op_out_shape.lens())
+            {
+                MIGRAPHX_THROW(this->name() + " : dimension mismatch, operand C: {" +
+                               to_string_range(cmat_shape.lens()) +
+                               "}, cannot add to operand A * B: {" +
+                               to_string_range(op_out_shape.lens()) + "}");
+            }
+            if(cmat_shape.type() != op_out_shape.type())
+            {
+                MIGRAPHX_THROW(this->name() + " : operand C type mismatch, operand C is of type: " +
+                               to_string(cmat_shape.type()) +
+                               ", it must be: " + to_string(op_out_shape.type()));
+            }
+        }
         return op.compute_shape(in_shapes);
     }
@@ -51,7 +78,14 @@ struct rocblas_gemm
     argument
     compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
     {
-        gemm(ctx, output_shape, args, op.alpha, op.beta, int8_x4_format);
+        if(this->name() == "gpu::gemm")
+        {
+            gemm(ctx, output_shape, args, alpha, beta, int8_x4_format);
+        }
+        else
+        {
+            gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), int8_x4_format);
+        }
         return args.back();
     }
@@ -304,17 +304,14 @@ struct miopen_apply
         });
     }

-    template <class Op>
-    void add_gemm_op(std::string name)
+    template <typename Op>
+    void add_gemm_op(const std::string& name)
     {
         apply_map.emplace(name, [=](instruction_ref ins) {
-            auto&& op = any_cast<Op>(ins->get_operator());
-            auto beta = op.beta;
             std::vector<instruction_ref> refs = ins->inputs();
             if(refs.size() == 2)
             {
                 auto output = insert_allocation(ins, ins->get_shape());
-                beta = 0;
                 refs.push_back(output);
             }
             else
@@ -333,9 +330,8 @@ struct miopen_apply
                 refs.push_back(refs.back());
                 }
             }
-
             return mod->replace_instruction(
-                ins, rocblas_gemm<Op>{Op{op.alpha, beta}, int8_x4_format}, refs);
+                ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format}, refs);
         });
     }