Unverified commit 9e43cb8b, authored by Umang Yadav, committed by GitHub

Remove alpha and beta attributes from dot operator (#945)

This PR removes the alpha and beta attributes from the dot operator completely.

Previously, the dot operator was defined as C = alpha * (A . B) + beta * C, where * is scalar multiplication and . is a dot product or matrix multiplication, depending on the dimensions of the inputs.

The aim is to define the dot operator simply as C = A . B, with no alpha or beta.

To achieve the same effect as alpha and beta: (1) one of the inputs to the dot operator is multiplied by the alpha value; (2) if beta is present, C is multiplied by beta and the result is added to the output of step 1, as sketched below.
parent 31dc067e
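
Concretely, the rewrite wraps an attribute-free dot in explicit elementwise instructions. A condensed C++ sketch of the new helper's shape (m, pos, a, b, c, alpha, and beta are assumed in scope; the type converts and shape broadcasts of the real implementation below are elided):

    // Step 1: scale one of the dot inputs by alpha (skipped when alpha == 1).
    auto scaled_a = insert_common_op(m, pos, make_op("mul"), {m.add_literal(alpha), a});
    // The dot instruction itself now computes only A . B.
    auto d = m.insert_instruction(pos, make_op("dot"), scaled_a, b);
    // Step 2: scale C by beta, then add it to the product (skipped when beta == 0).
    auto scaled_c = insert_common_op(m, pos, make_op("mul"), {c, m.add_literal(beta)});
    auto out = m.insert_instruction(pos, make_op("add"), d, scaled_c);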
@@ -52,7 +52,6 @@ add_library(migraphx
     reduce_dims.cpp
     register_op.cpp
     register_target.cpp
-    remap.cpp
     simplify_qdq.cpp
     rewrite_batchnorm.cpp
     rewrite_pooling.cpp
...
@@ -113,5 +113,58 @@ instruction_ref add_common_op(module& m, const operation& op, std::vector<instru
     return insert_common_op(m, m.end(), op, std::move(inputs));
 }

+instruction_ref insert_dot_apply_alpha_beta(module& m,
+                                            instruction_ref pos,
+                                            const std::vector<instruction_ref>& args,
+                                            float alpha,
+                                            float beta)
+{
+    auto l1       = args[0];
+    auto l2       = args[1];
+    auto dot_type = l1->get_shape().type();
+    if(!float_equal(alpha, 1.0f))
+    {
+        auto alpha_literal = m.add_literal(alpha);
+        l1 = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, l1});
+        if(l1->get_shape().type() != dot_type)
+        {
+            l1 = m.insert_instruction(pos, make_op("convert", {{"target_type", dot_type}}), l1);
+        }
+    }
+    auto dot_res = m.insert_instruction(pos, migraphx::make_op("dot"), l1, l2);
+    if(args.size() == 3)
+    {
+        if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
+        {
+            auto out_lens   = l1->get_shape().lens();
+            out_lens.back() = l2->get_shape().lens().back();
+            auto l3         = args[2];
+            auto l3_lens    = l3->get_shape().lens();
+            if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
+            {
+                l3 = m.insert_instruction(
+                    pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
+            }
+            auto beta_literal = m.add_literal(beta);
+            auto beta_l3 = insert_common_op(m, pos, migraphx::make_op("mul"), {l3, beta_literal});
+            if(beta_l3->get_shape().type() != dot_type)
+            {
+                beta_l3 = m.insert_instruction(
+                    pos, migraphx::make_op("convert", {{"target_type", dot_type}}), beta_l3);
+            }
+            return m.insert_instruction(pos, migraphx::make_op("add"), dot_res, beta_l3);
+        }
+    }
+    return dot_res;
+}
+
+instruction_ref add_dot_apply_alpha_beta(module& m,
+                                         const std::vector<instruction_ref>& args,
+                                         float alpha,
+                                         float beta)
+{
+    return insert_dot_apply_alpha_beta(m, m.end(), args, alpha, beta);
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
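
For callers, the old three-input dot with attributes maps one-to-one onto the new helper. A minimal usage sketch (it assumes a migraphx::module m and instruction_refs a, b, and c already in the graph, mirroring the call sites in the updated model tests below):

    // Defaults are alpha = 1.0, beta = 0.0, so this is the plain product A . B.
    auto d0 = migraphx::add_dot_apply_alpha_beta(m, {a, b});
    // GEMM semantics: 0.5 * (A . B) + 2.0 * C, expanded into mul/dot/add
    // instructions (plus any needed convert/multibroadcast) appended to m.
    auto d1 = migraphx::add_dot_apply_alpha_beta(m, {a, b, c}, 0.5f, 2.0f);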
@@ -27,7 +27,7 @@ alpha_beta get_alpha_beta(const operation& op)
 struct find_dot_add
 {
-    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(3)); }
+    auto matcher() const { return match::name("quant_dot")(match::nargs(3)); }
     void apply(module& p, const match::matcher_result& r) const
     {
@@ -58,7 +58,7 @@ struct find_dot_add
 struct find_dot_alpha
 {
-    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(2)); }
+    auto matcher() const { return match::name("quant_dot")(match::nargs(2)); }
     void apply(module& p, const match::matcher_result& r) const
     {
...
 #include <migraphx/operators.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/generate.hpp>
+#include <migraphx/common.hpp>
 #include "models.hpp"

 namespace migraphx {
@@ -144,10 +145,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
     migraphx::op::multibroadcast multibroadcast42;
     multibroadcast42.output_lens = {batch, 4096};
     auto mx42 = mm->add_instruction(multibroadcast42, mx4);
-    migraphx::op::dot dot43;
-    dot43.alpha = 1;
-    dot43.beta  = 1;
-    auto mx43 = mm->add_instruction(dot43, mx40, mx41, mx42);
+    float dot43_alpha = 1;
+    float dot43_beta  = 1;
+    auto mx43 =
+        migraphx::add_dot_apply_alpha_beta(*mm, {mx40, mx41, mx42}, dot43_alpha, dot43_beta);
     migraphx::op::relu relu44;
     auto mx44 = mm->add_instruction(relu44, mx43);
     migraphx::op::identity identity45;
@@ -158,10 +159,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
     migraphx::op::multibroadcast multibroadcast47;
     multibroadcast47.output_lens = {batch, 4096};
     auto mx47 = mm->add_instruction(multibroadcast47, mx2);
-    migraphx::op::dot dot48;
-    dot48.alpha = 1;
-    dot48.beta  = 1;
-    auto mx48 = mm->add_instruction(dot48, mx45, mx46, mx47);
+    float dot48_alpha = 1;
+    float dot48_beta  = 1;
+    auto mx48 =
+        migraphx::add_dot_apply_alpha_beta(*mm, {mx45, mx46, mx47}, dot48_alpha, dot48_beta);
     migraphx::op::relu relu49;
     auto mx49 = mm->add_instruction(relu49, mx48);
     migraphx::op::transpose transpose50;
@@ -170,10 +171,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
     migraphx::op::multibroadcast multibroadcast51;
     multibroadcast51.output_lens = {batch, 1000};
     auto mx51 = mm->add_instruction(multibroadcast51, mx0);
-    migraphx::op::dot dot52;
-    dot52.alpha = 1;
-    dot52.beta  = 1;
-    mm->add_instruction(dot52, mx49, mx50, mx51);
+    float dot52_alpha = 1;
+    float dot52_beta  = 1;
+    migraphx::add_dot_apply_alpha_beta(*mm, {mx49, mx50, mx51}, dot52_alpha, dot52_beta);
     return p;
 }
...
 #include <migraphx/operators.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/generate.hpp>
+#include <migraphx/common.hpp>
 #include "models.hpp"

 namespace migraphx {
@@ -2225,10 +2226,9 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
     migraphx::op::multibroadcast multibroadcast798;
     multibroadcast798.output_lens = {batch, 1000};
     auto mx798 = mm->add_instruction(multibroadcast798, mx0);
-    migraphx::op::dot dot799;
-    dot799.alpha = 1;
-    dot799.beta  = 1;
-    mm->add_instruction(dot799, mx796, mx797, mx798);
+    float dot799_alpha = 1;
+    float dot799_beta  = 1;
+    migraphx::add_dot_apply_alpha_beta(*mm, {mx796, mx797, mx798}, dot799_alpha, dot799_beta);
     return p;
 }
...
 #include <migraphx/operators.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/generate.hpp>
+#include <migraphx/common.hpp>
 #include "models.hpp"

 namespace migraphx {
@@ -1228,10 +1229,9 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
     migraphx::op::multibroadcast multibroadcast442;
     multibroadcast442.output_lens = {batch, 1000};
     auto mx442 = mm->add_instruction(multibroadcast442, mx0);
-    migraphx::op::dot dot443;
-    dot443.alpha = 1;
-    dot443.beta  = 1;
-    mm->add_instruction(dot443, mx440, mx441, mx442);
+    float dot443_alpha = 1;
+    float dot443_beta  = 1;
+    migraphx::add_dot_apply_alpha_beta(*mm, {mx440, mx441, mx442}, dot443_alpha, dot443_beta);
     return p;
 }
...
#include "migraphx/common.hpp"
#include "migraphx/errors.hpp"
#include "migraphx/float_equal.hpp"
#include <cmath>
#include <cstdint>
#include <migraphx/eliminate_data_type.hpp> #include <migraphx/eliminate_data_type.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -28,11 +34,25 @@ void eliminate_data_type::apply(module& m) const ...@@ -28,11 +34,25 @@ void eliminate_data_type::apply(module& m) const
continue; continue;
auto op = ins->get_operator(); auto op = ins->get_operator();
auto attributes = op.attributes(); auto attributes = op.attributes();
auto old_type = ins->get_shape().type();
auto val = op.to_value();
if(attributes.contains("general_data_type")) if(attributes.contains("general_data_type"))
{ {
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value()); if(ins->name() == "quant_dot")
{
auto alpha = val.at("alpha").to<std::float_t>();
auto beta = val.at("beta").to<std::float_t>();
auto dot_res = migraphx::insert_dot_apply_alpha_beta(m, ins, inputs, alpha, beta);
auto convert = m.insert_instruction(
ins, make_op("convert", {{"target_type", old_type}}), dot_res);
m.replace_instruction(ins, convert);
return;
}
else
{
op = make_op(attributes["general_data_type"].to<std::string>(), val);
}
} }
auto old_type = ins->get_shape().type();
auto out = m.insert_instruction(ins, op, inputs); auto out = m.insert_instruction(ins, op, inputs);
auto convert = auto convert =
m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out); m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out);
......
@@ -21,6 +21,17 @@ instruction_ref insert_common_op(module& m,
                                  std::vector<instruction_ref> inputs);

 instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs);

+instruction_ref insert_dot_apply_alpha_beta(module& m,
+                                            instruction_ref pos,
+                                            const std::vector<instruction_ref>& args,
+                                            float alpha = 1.0,
+                                            float beta  = 0.0);
+
+instruction_ref add_dot_apply_alpha_beta(module& m,
+                                         const std::vector<instruction_ref>& args,
+                                         float alpha = 1.0,
+                                         float beta  = 0.0);
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx

 #endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP
@@ -18,15 +18,6 @@ namespace op {
 struct dot
 {
-    float alpha = 1.0;
-    float beta  = 1.0;
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
-    }
-
     std::string name() const { return "dot"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
@@ -76,7 +67,7 @@ struct dot
         else
             result = argument{output_shape};
         visit_all(result, args[0], args[1])(
-            [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, alpha, beta); });
+            [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1, 0); });
         return result;
     }
 };
...
-#ifndef MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
-#define MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
-
-#include <string>
-#include <migraphx/instruction_ref.hpp>
-#include <migraphx/config.hpp>
-
-namespace migraphx {
-inline namespace MIGRAPHX_INLINE_NS {
-
-struct module;
-
-/**
- * Decompose operators.
- */
-struct remap
-{
-    std::string name() const { return "remap"; }
-    void apply(module& p) const;
-};
-
-} // namespace MIGRAPHX_INLINE_NS
-} // namespace migraphx
-
-#endif
@@ -61,7 +61,7 @@ struct parse_gemm : op_parser<parse_gemm>
                      ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
                      : args[1];
-        auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2);
+        auto ret = info.add_instruction(make_op("dot"), l1, l2);
         if(args.size() == 3)
         {
...
@@ -66,9 +66,16 @@ struct parse_matmul : op_parser<parse_matmul>
                     make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
             }
         }
-        auto dot_res =
-            info.add_instruction(make_op(opd.op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
+        instruction_ref dot_res;
+        if(opd.op_name == "dot")
+        {
+            dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
+        }
+        else
+        {
+            dot_res =
+                info.add_instruction(make_op(opd.op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
+        }
         int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
         if(is_a_prepended)
         {
...
-#include <migraphx/remap.hpp>
-#include <migraphx/program.hpp>
-#include <migraphx/instruction.hpp>
-#include <migraphx/iterator_for.hpp>
-#include <migraphx/functional.hpp>
-#include <migraphx/ranges.hpp>
-#include <migraphx/float_equal.hpp>
-#include <migraphx/matcher.hpp>
-#include <migraphx/op/dot.hpp>
-#include <migraphx/op/add.hpp>
-
-namespace migraphx {
-inline namespace MIGRAPHX_INLINE_NS {
-
-namespace {
-struct find_dot_add
-{
-    auto matcher() const
-    {
-        return match::name("add")(match::any_of(
-            match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")),
-            match::args(match::used_once().bind("a"),
-                        match::name("dot")(match::nargs(2)).bind("dot"))));
-    }
-
-    void apply(module& p, match::matcher_result r) const
-    {
-        auto ins     = r.result;
-        auto dot_ins = r.instructions["dot"];
-        auto a_ins   = r.instructions["a"];
-        auto dot     = any_cast<op::dot>(dot_ins->get_operator());
-        dot.beta     = 1;
-        p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins);
-    }
-};
-} // namespace
-
-void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
-
-} // namespace MIGRAPHX_INLINE_NS
-} // namespace migraphx
@@ -84,11 +84,6 @@ struct match_find_quantizable_ops
         }
         else if(qop->name() == "dot")
         {
-            auto dot_op = any_cast<op::dot>(qop->get_operator());
-            if(!(float_equal(dot_op.alpha, 1.0f) and float_equal(dot_op.beta, 0.0f)))
-                return;
-            if(qop_args.size() == 3)
-                qop_args.pop_back();
             dq = m.insert_instruction(
                 qop, migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qop_args);
         }
...
@@ -14,7 +14,6 @@
 #include <migraphx/memory_coloring.hpp>
 #include <migraphx/propagate_constant.hpp>
 #include <migraphx/register_target.hpp>
-#include <migraphx/remap.hpp>
 #include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/rewrite_pooling.hpp>
 #include <migraphx/rewrite_quantization.hpp>
...
@@ -697,52 +697,6 @@ struct find_conv_bias_relu
         apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r));
     }
 };

-struct find_gemm_add
-{
-    auto matcher() const
-    {
-        return match::name("gpu::add")(
-            match::all_of[match::inputs()](match::standard_shape()),
-            match::either_arg(0, 1)(match::used_once().bind("c"),
-                                    match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
-    }
-
-    void apply(module& p, match::matcher_result r) const
-    {
-        auto ins      = r.result;
-        auto gemm_ins = r.instructions["gemm"];
-        auto c_ins    = r.instructions["c"];
-        auto gemm     = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
-        // Already fused gemm
-        if(not float_equal(gemm.op.beta, 0))
-            return;
-        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
-               return not i->get_shape().standard();
-           }))
-            return;
-        auto inputs = gemm_ins->inputs();
-        inputs.pop_back();
-
-        auto copy_ins = c_ins;
-        // Insert copy
-        if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
-        {
-            copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
-        }
-
-        inputs.push_back(copy_ins);
-        inputs.push_back(copy_ins);
-        gemm.op.beta = 1;
-        p.replace_instruction(ins, gemm, inputs);
-    }
-};

 struct find_commutative_broadcast
 {
     auto matcher() const
@@ -778,7 +732,7 @@ void fuse_ops::apply(module& p) const
                         find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
                         find_add_clip{});
     run_passes(p, {dead_code_elimination{}});
-    match::find_matches(p, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
+    match::find_matches(p, find_triadd_layernorm{}, find_commutative_broadcast{});
 }

 } // namespace gpu
...
 #ifndef MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP
 #define MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP

+#include <migraphx/operation.hpp>
+#include <migraphx/value.hpp>
 #include <migraphx/shape.hpp>
 #include <migraphx/reflect.hpp>
 #include <migraphx/gpu/context.hpp>
@@ -20,6 +22,8 @@ struct rocblas_gemm
 {
     Op op;
     bool int8_x4_format = true;
+    float alpha         = 1.0;
+    float beta          = 0.0;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -51,7 +55,7 @@ struct rocblas_gemm
     argument
     compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
     {
-        gemm(ctx, output_shape, args, op.alpha, op.beta, int8_x4_format);
+        gemm(ctx, output_shape, args, alpha, beta, int8_x4_format);
         return args.back();
     }
...
+#include <cstring>
 #include <iterator>

 #include <migraphx/gpu/lowering.hpp>
 #include <migraphx/manage_ptr.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/operation.hpp>
+#include <migraphx/value.hpp>

 #include <migraphx/op/abs.hpp>
 #include <migraphx/op/batch_norm_inference.hpp>
@@ -180,8 +183,8 @@ struct miopen_apply
         add_extend_op("softmax");
         add_extend_op("topk");

-        add_gemm_op<op::dot>("dot");
-        add_gemm_op<op::quant_dot>("quant_dot");
+        add_gemm_op("dot");
+        add_int8_gemm_op("quant_dot");
         add_convolution_op();
         add_deconvolution_op();
         add_quant_convolution_op();
@@ -303,13 +306,12 @@ struct miopen_apply
         });
     }

-    template <class Op>
-    void add_gemm_op(std::string name)
+    void add_int8_gemm_op(const std::string& name)
     {
         apply_map.emplace(name, [=](instruction_ref ins) {
-            auto&& op = any_cast<Op>(ins->get_operator());
-            auto beta = op.beta;
+            auto&& op = any_cast<op::quant_dot>(ins->get_operator());
             std::vector<instruction_ref> refs = ins->inputs();
+            auto beta = op.beta;
             if(refs.size() == 2)
             {
                 auto output = insert_allocation(ins, ins->get_shape());
@@ -334,7 +336,42 @@ struct miopen_apply
             }

             return mod->replace_instruction(
-                ins, rocblas_gemm<Op>{Op{op.alpha, beta}, int8_x4_format}, refs);
+                ins,
+                rocblas_gemm<op::quant_dot>{op::quant_dot{op.alpha, beta},
+                                            int8_x4_format,
+                                            static_cast<float>(op.alpha),
+                                            static_cast<float>(beta)},
+                refs);
+        });
+    };
+
+    void add_gemm_op(const std::string& name)
+    {
+        apply_map.emplace(name, [=](instruction_ref ins) {
+            std::vector<instruction_ref> refs = ins->inputs();
+            if(refs.size() == 2)
+            {
+                auto output = insert_allocation(ins, ins->get_shape());
+                refs.push_back(output);
+            }
+            else
+            {
+                auto c_alias = instruction::get_output_alias(refs.back());
+                if(ins == last or refs.back()->outputs().size() > 1 or c_alias->inputs().empty())
+                {
+                    auto output = insert_allocation(ins, ins->get_shape());
+                    auto copy_out =
+                        mod->insert_instruction(ins, make_op("hip::copy"), refs.back(), output);
+                    refs.back() = copy_out;
+                    refs.push_back(copy_out);
+                }
+                else
+                {
+                    refs.push_back(refs.back());
+                }
+            }
+
+            return mod->replace_instruction(
+                ins, rocblas_gemm<op::dot>{op::dot{}, int8_x4_format, 1, 0}, refs);
         });
     }
...
@@ -17,7 +17,6 @@
 #include <migraphx/preallocate_param.hpp>
 #include <migraphx/propagate_constant.hpp>
 #include <migraphx/register_target.hpp>
-#include <migraphx/remap.hpp>
 #include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/rewrite_pooling.hpp>
 #include <migraphx/rewrite_quantization.hpp>
...
@@ -518,42 +518,12 @@ struct ref_gemm
         return migraphx::reflect(self.op, f);
     }

     std::string name() const { return "ref::dot"; }
-    shape compute_shape(const std::vector<shape>& inputs) const
-    {
-        if(inputs.size() == 3)
-        {
-            auto c_shape = inputs.at(2);
-            check_shapes{{c_shape}, *this}.not_broadcasted();
-        }
-        return op.compute_shape(inputs);
-    }
+    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        // 3 inputs, it is alpha * A * B + beta * C, then
-        // A and B are matrices, and C is of the same shape as A * B
-        if(args.size() == 3)
-        {
-            // no need to consider the value of args[2]
-            if(op.beta == 0.0f)
-            {
-                result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
-            }
-            else
-            {
-                visit_all(result, args[2])([&](auto output, auto input) {
-                    std::copy(input.begin(), input.end(), output.begin());
-                });
-            }
-            migemm(result, args[0], args[1], op.alpha, op.beta);
-            return result;
-        }
-        // 2 input arguments
-        migemm(result, args[0], args[1], op.alpha, 0.0f);
+        migemm(result, args[0], args[1], 1.0f, 0.0f);
         return result;
     }
...