Unverified Commit 21193e87 authored by Umang Yadav, committed by GitHub

Remove alpha and beta from `dot` and `quant_dot` (#961)

Previously, the `dot` operator was defined as `C = alpha * A . B + beta * C`, where `*` is scalar multiplication and `.` is the dot product or matrix multiplication, depending on the dimensions of the inputs.

The aim is to define the `dot` operator simply as `C = A . B`, with no alpha or beta.

To achieve the same effect as alpha and beta: (1) one of the inputs to the dot operator is multiplied by the alpha value; (2) if beta is present, `C` is multiplied by beta and the result is added to the output of step (1).
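In graph terms, the rewrite looks roughly like this (a sketch mirroring the new `insert_apply_alpha_beta` helper added in this commit; `m`, `pos`, `a`, `b`, `c`, and the two literals are illustrative names, not part of the diff):

// C = alpha * (A . B) + beta * C, emitted as explicit mul/dot/add instructions
auto a_scaled = insert_common_op(m, pos, make_op("mul"), {alpha_literal, a}); // step (1)
auto ab       = m.insert_instruction(pos, make_op("dot"), a_scaled, b);       // plain A . B
auto beta_c   = insert_common_op(m, pos, make_op("mul"), {c, beta_literal});  // step (2)
auto out      = m.insert_instruction(pos, make_op("add"), ab, beta_c);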
parent 87978f03
@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag)
add_library(migraphx
adjust_allocation.cpp
analyze_streams.cpp
apply_alpha_beta.cpp
argument.cpp
auto_contiguous.cpp
common.cpp
@@ -14,7 +15,6 @@ add_library(migraphx
convert_to_json.cpp
cpp_generator.cpp
dead_code_elimination.cpp
decompose.cpp
dom_info.cpp
dynamic_loader.cpp
eliminate_allocation.cpp
@@ -52,7 +52,6 @@ add_library(migraphx
reduce_dims.cpp
register_op.cpp
register_target.cpp
remap.cpp
simplify_qdq.cpp
rewrite_batchnorm.cpp
rewrite_pooling.cpp
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/apply_alpha_beta.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
const literal& alpha,
const literal& beta)
{
auto a = args[0];
auto b = args[1];
auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0))
{
auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
if(a->get_shape().type() != input_type)
{
a = m.insert_instruction(pos, make_op("convert", {{"target_type", input_type}}), a);
}
}
auto op_res = m.insert_instruction(pos, op, a, b);
if(args.size() == 3)
{
if(not float_equal(beta.at<float>(0), 0.0) && args[2]->get_shape().elements() > 0)
{
auto out_lens = op_res->get_shape().lens();
auto c = args[2];
auto c_lens = c->get_shape().lens();
input_type = c->get_shape().type();
if(out_lens != c_lens)
{
c = m.insert_instruction(
pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
auto beta_literal = m.add_literal(beta);
auto beta_c = insert_common_op(m, pos, migraphx::make_op("mul"), {c, beta_literal});
if(beta_c->get_shape().type() != input_type)
{
beta_c = m.insert_instruction(
pos, migraphx::make_op("convert", {{"target_type", input_type}}), beta_c);
}
return m.insert_instruction(pos, migraphx::make_op("add"), op_res, beta_c);
}
}
return op_res;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
File mode changed from 100755 to 100644
#include <migraphx/decompose.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
struct alpha_beta
{
float alpha = 0.0;
float beta = 0.0;
};
alpha_beta get_alpha_beta(const operation& op)
{
auto v = op.to_value();
return {v.at("alpha").to<float>(), v.at("beta").to<float>()};
}
struct find_dot_add
{
auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(3)); }
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot = get_alpha_beta(ins->get_operator());
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
if(not float_equal(dot.alpha, 1))
{
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
auto c_ins = ins->inputs()[2];
if(not float_equal(dot.beta, 1))
{
auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), beta);
c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
}
p.replace_instruction(ins, make_op("add"), dot_ins, c_ins);
}
};
struct find_dot_alpha
{
auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(2)); }
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot = get_alpha_beta(ins->get_operator());
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
if(not float_equal(dot.alpha, 1))
{
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
}
};
} // namespace
void decompose::apply(module& p) const { match::find_matches(p, find_dot_add{}, find_dot_alpha{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp"
namespace migraphx {
@@ -144,10 +145,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast42;
multibroadcast42.output_lens = {batch, 4096};
auto mx42 = mm->add_instruction(multibroadcast42, mx4);
migraphx::op::dot dot43;
dot43.alpha = 1;
dot43.beta = 1;
auto mx43 = mm->add_instruction(dot43, mx40, mx41, mx42);
float dot43_alpha = 1;
float dot43_beta = 1;
auto mx43 = migraphx::add_apply_alpha_beta(
*mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta);
migraphx::op::relu relu44;
auto mx44 = mm->add_instruction(relu44, mx43);
migraphx::op::identity identity45;
@@ -158,10 +159,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast47;
multibroadcast47.output_lens = {batch, 4096};
auto mx47 = mm->add_instruction(multibroadcast47, mx2);
migraphx::op::dot dot48;
dot48.alpha = 1;
dot48.beta = 1;
auto mx48 = mm->add_instruction(dot48, mx45, mx46, mx47);
float dot48_alpha = 1;
float dot48_beta = 1;
auto mx48 = migraphx::add_apply_alpha_beta(
*mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta);
migraphx::op::relu relu49;
auto mx49 = mm->add_instruction(relu49, mx48);
migraphx::op::transpose transpose50;
@@ -170,10 +171,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast51;
multibroadcast51.output_lens = {batch, 1000};
auto mx51 = mm->add_instruction(multibroadcast51, mx0);
migraphx::op::dot dot52;
dot52.alpha = 1;
dot52.beta = 1;
mm->add_instruction(dot52, mx49, mx50, mx51);
float dot52_alpha = 1;
float dot52_beta = 1;
migraphx::add_apply_alpha_beta(
*mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
return p;
}
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp"
namespace migraphx {
@@ -2225,10 +2226,10 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast798;
multibroadcast798.output_lens = {batch, 1000};
auto mx798 = mm->add_instruction(multibroadcast798, mx0);
migraphx::op::dot dot799;
dot799.alpha = 1;
dot799.beta = 1;
mm->add_instruction(dot799, mx796, mx797, mx798);
float dot799_alpha = 1;
float dot799_beta = 1;
migraphx::add_apply_alpha_beta(
*mm, {mx796, mx797, mx798}, migraphx::make_op("dot"), dot799_alpha, dot799_beta);
return p;
}
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp"
namespace migraphx {
@@ -1228,10 +1229,10 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast442;
multibroadcast442.output_lens = {batch, 1000};
auto mx442 = mm->add_instruction(multibroadcast442, mx0);
migraphx::op::dot dot443;
dot443.alpha = 1;
dot443.beta = 1;
mm->add_instruction(dot443, mx440, mx441, mx442);
float dot443_alpha = 1;
float dot443_beta = 1;
migraphx::add_apply_alpha_beta(
*mm, {mx440, mx441, mx442}, migraphx::make_op("dot"), dot443_alpha, dot443_beta);
return p;
}
#ifndef MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
#include "migraphx/make_op.hpp"
#include "migraphx/normalize_attributes.hpp"
#include "migraphx/operation.hpp"
#include <migraphx/instruction_ref.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
const literal& alpha,
const literal& beta);
template <typename T = float>
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
T alpha = 1,
T beta = 0)
{
return insert_apply_alpha_beta(m, pos, args, op, literal{T{alpha}}, literal{T{beta}});
}
template <typename T = float>
instruction_ref add_apply_alpha_beta(module& m,
const std::vector<instruction_ref>& args,
const operation& op,
T alpha = 1,
T beta = 0)
{
return insert_apply_alpha_beta(m, m.end(), args, op, alpha, beta);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
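With this header, a call site reduces to a single helper invocation (a sketch based on the updated model builders later in this diff; `mm`, `a`, `b`, and `c` are placeholder names):

// Emits the extra mul/add instructions only when needed; with alpha == 1 and
// beta == 0 this reduces to a plain dot.
auto r = migraphx::add_apply_alpha_beta(*mm, {a, b, c}, migraphx::make_op("dot"), 1.0f, 1.0f);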
#ifndef MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Decompose operators.
*/
struct decompose
{
std::string name() const { return "decompose"; }
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -18,19 +18,10 @@ namespace op {
struct dot
{
float alpha = 1.0;
float beta = 1.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_type();
check_shapes{inputs, *this}.same_type().has(2);
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
@@ -58,25 +49,14 @@ struct dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
return {t, out_lens};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result;
if(args.size() == 3)
result = args[2];
else
result = argument{output_shape};
argument result = argument{output_shape};
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, alpha, beta); });
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result;
}
};
@@ -18,21 +18,12 @@ namespace op {
struct quant_dot
{
int32_t alpha = 1;
int32_t beta = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
value attributes() const { return {{"general_data_type", "dot"}}; }
std::string name() const { return "quant_dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type();
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type().has(2);
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
@@ -64,18 +55,6 @@ struct quant_dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("QUANT_DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
if(inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
{
MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
}
return {shape::int32_type, out_lens};
}
};
#ifndef MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
#define MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Decompose operators.
*/
struct remap
{
std::string name() const { return "remap"; }
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -61,7 +61,7 @@ struct parse_gemm : op_parser<parse_gemm>
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2);
auto ret = info.add_instruction(make_op("dot"), l1, l2);
if(args.size() == 3)
{
@@ -66,9 +66,7 @@ struct parse_matmul : op_parser<parse_matmul>
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
}
}
auto dot_res =
info.add_instruction(make_op(opd.op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
{
#include <migraphx/remap.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/add.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
struct find_dot_add
{
auto matcher() const
{
return match::name("add")(match::any_of(
match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")),
match::args(match::used_once().bind("a"),
match::name("dot")(match::nargs(2)).bind("dot"))));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = r.instructions["a"];
auto dot = any_cast<op::dot>(dot_ins->get_operator());
dot.beta = 1;
p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins);
}
};
} // namespace
void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -84,13 +84,7 @@ struct match_find_quantizable_ops
}
else if(qop->name() == "dot")
{
auto dot_op = any_cast<op::dot>(qop->get_operator());
if(!(float_equal(dot_op.alpha, 1.0f) and float_equal(dot_op.beta, 0.0f)))
return;
if(qop_args.size() == 3)
qop_args.pop_back();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qop_args);
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
}
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
@@ -3,7 +3,6 @@
#include <migraphx/check_context.hpp>
#include <migraphx/adjust_allocation.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/decompose.hpp>
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_concat.hpp>
@@ -14,7 +13,6 @@
#include <migraphx/memory_coloring.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
@@ -52,8 +50,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
dead_code_elimination{},
decompose{},
dead_code_elimination{},
simplify_reshapes{},
eliminate_identity{},
eliminate_pad{},
@@ -717,7 +717,7 @@ struct find_gemm_add
auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
// Already fused gemm
if(not float_equal(gemm.op.beta, 0))
if(not float_equal(gemm.beta, 0))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
@@ -738,7 +738,7 @@ struct find_gemm_add
inputs.push_back(copy_ins);
inputs.push_back(copy_ins);
gemm.op.beta = 1;
gemm.beta = 1;
p.replace_instruction(ins, gemm, inputs);
}
};
#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP
#include <migraphx/errors.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/gpu/context.hpp>
@@ -19,13 +22,17 @@ template <class Op>
struct rocblas_gemm
{
Op op;
float alpha = 1;
float beta = 0;
bool int8_x4_format = true;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack_join(migraphx::reflect(self.op, f),
pack(f(self.int8_x4_format, "int8_x4_format")));
pack(f(self.alpha, "alpha"),
f(self.beta, "beta"),
f(self.int8_x4_format, "int8_x4_format")));
}
std::string name() const
@@ -44,6 +51,26 @@ struct rocblas_gemm
check_shapes{in_shapes, *this}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
// if gemm and add are fused
if(not float_equal(beta, 0))
{
auto cmat_shape = in_shapes.back();
in_shapes.pop_back();
auto op_out_shape = op.compute_shape(in_shapes);
if(cmat_shape.lens() != op_out_shape.lens())
{
MIGRAPHX_THROW(this->name() + " : dimension mismatch, operand C: {" +
to_string_range(cmat_shape.lens()) +
"}, cannot add to operand A * B: {" +
to_string_range(op_out_shape.lens()) + "}");
}
if(cmat_shape.type() != op_out_shape.type())
{
MIGRAPHX_THROW(this->name() + " : operand C type mismatch, operand C is of type: " +
to_string(cmat_shape.type()) +
", it must be: " + to_string(op_out_shape.type()));
}
}
return op.compute_shape(in_shapes);
}
@@ -51,7 +78,14 @@
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
gemm(ctx, output_shape, args, op.alpha, op.beta, int8_x4_format);
if(this->name() == "gpu::gemm")
{
gemm(ctx, output_shape, args, alpha, beta, int8_x4_format);
}
else
{
gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), int8_x4_format);
}
return args.back();
}
@@ -304,17 +304,14 @@ struct miopen_apply
});
}
template <class Op>
void add_gemm_op(std::string name)
template <typename Op>
void add_gemm_op(const std::string& name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
auto&& op = any_cast<Op>(ins->get_operator());
auto beta = op.beta;
std::vector<instruction_ref> refs = ins->inputs();
if(refs.size() == 2)
{
auto output = insert_allocation(ins, ins->get_shape());
beta = 0;
refs.push_back(output);
}
else
@@ -333,9 +330,8 @@ struct miopen_apply
refs.push_back(refs.back());
}
}
return mod->replace_instruction(
ins, rocblas_gemm<Op>{Op{op.alpha, beta}, int8_x4_format}, refs);
ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format}, refs);
});
}