Unverified Commit 985f58b0 authored by Paul Fultz II, committed by GitHub

Revert "Remove alpha and beta attributes from dot operator (#945)" (#957)

This reverts commit 9e43cb8b.
parent 9e43cb8b
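
In short, this revert restores the alpha and beta attributes on the dot operator, so a three-input dot once again computes alpha * A * B + beta * C instead of going through the insert_dot_apply_alpha_beta helpers. A minimal sketch of building such an instruction against the restored interface; the shapes and scale values are illustrative, and get_main_module/generate_literal are assumed helpers rather than part of this diff:

#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

// Sketch only: builds 0.5 * A * B + 1.0 * C using the restored attributes.
migraphx::program build_scaled_gemm()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {4, 4}};
    auto a = mm->add_literal(migraphx::generate_literal(s, 0)); // A
    auto b = mm->add_literal(migraphx::generate_literal(s, 1)); // B
    auto c = mm->add_literal(migraphx::generate_literal(s, 2)); // C, same shape as A * B
    mm->add_instruction(
        migraphx::make_op("dot", {{"alpha", 0.5f}, {"beta", 1.0f}}), a, b, c);
    return p;
}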
...@@ -52,6 +52,7 @@ add_library(migraphx
reduce_dims.cpp
register_op.cpp
register_target.cpp
remap.cpp
simplify_qdq.cpp
rewrite_batchnorm.cpp
rewrite_pooling.cpp
......
...@@ -113,58 +113,5 @@ instruction_ref add_common_op(module& m, const operation& op, std::vector<instru
return insert_common_op(m, m.end(), op, std::move(inputs));
}
instruction_ref insert_dot_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
float alpha,
float beta)
{
auto l1 = args[0];
auto l2 = args[1];
auto dot_type = l1->get_shape().type();
if(!float_equal(alpha, 1.0f))
{
auto alpha_literal = m.add_literal(alpha);
l1 = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, l1});
if(l1->get_shape().type() != dot_type)
{
l1 = m.insert_instruction(pos, make_op("convert", {{"target_type", dot_type}}), l1);
}
}
auto dot_res = m.insert_instruction(pos, migraphx::make_op("dot"), l1, l2);
if(args.size() == 3)
{
if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
{
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = m.insert_instruction(
pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
auto beta_literal = m.add_literal(beta);
auto beta_l3 = insert_common_op(m, pos, migraphx::make_op("mul"), {l3, beta_literal});
if(beta_l3->get_shape().type() != dot_type)
{
beta_l3 = m.insert_instruction(
pos, migraphx::make_op("convert", {{"target_type", dot_type}}), beta_l3);
}
return m.insert_instruction(pos, migraphx::make_op("add"), dot_res, beta_l3);
}
}
return dot_res;
}
instruction_ref add_dot_apply_alpha_beta(module& m,
const std::vector<instruction_ref>& args,
float alpha,
float beta)
{
return insert_dot_apply_alpha_beta(m, m.end(), args, alpha, beta);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -27,7 +27,7 @@ alpha_beta get_alpha_beta(const operation& op)
struct find_dot_add
{
-    auto matcher() const { return match::name("quant_dot")(match::nargs(3)); }
+    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(3)); }
    void apply(module& p, const match::matcher_result& r) const
    {
...@@ -58,7 +58,7 @@ struct find_dot_add
struct find_dot_alpha
{
-    auto matcher() const { return match::name("quant_dot")(match::nargs(2)); }
+    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(2)); }
    void apply(module& p, const match::matcher_result& r) const
    {
......
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/common.hpp>
#include "models.hpp" #include "models.hpp"
namespace migraphx { namespace migraphx {
...@@ -145,10 +144,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast42;
multibroadcast42.output_lens = {batch, 4096};
auto mx42 = mm->add_instruction(multibroadcast42, mx4);
-float dot43_alpha = 1;
-float dot43_beta = 1;
-auto mx43 =
-    migraphx::add_dot_apply_alpha_beta(*mm, {mx40, mx41, mx42}, dot43_alpha, dot43_beta);
+migraphx::op::dot dot43;
+dot43.alpha = 1;
+dot43.beta = 1;
+auto mx43 = mm->add_instruction(dot43, mx40, mx41, mx42);
migraphx::op::relu relu44;
auto mx44 = mm->add_instruction(relu44, mx43);
migraphx::op::identity identity45;
...@@ -159,10 +158,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast47;
multibroadcast47.output_lens = {batch, 4096};
auto mx47 = mm->add_instruction(multibroadcast47, mx2);
-float dot48_alpha = 1;
-float dot48_beta = 1;
-auto mx48 =
-    migraphx::add_dot_apply_alpha_beta(*mm, {mx45, mx46, mx47}, dot48_alpha, dot48_beta);
+migraphx::op::dot dot48;
+dot48.alpha = 1;
+dot48.beta = 1;
+auto mx48 = mm->add_instruction(dot48, mx45, mx46, mx47);
migraphx::op::relu relu49;
auto mx49 = mm->add_instruction(relu49, mx48);
migraphx::op::transpose transpose50;
...@@ -171,9 +170,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast51;
multibroadcast51.output_lens = {batch, 1000};
auto mx51 = mm->add_instruction(multibroadcast51, mx0);
-float dot52_alpha = 1;
-float dot52_beta = 1;
-migraphx::add_dot_apply_alpha_beta(*mm, {mx49, mx50, mx51}, dot52_alpha, dot52_beta);
+migraphx::op::dot dot52;
+dot52.alpha = 1;
+dot52.beta = 1;
+mm->add_instruction(dot52, mx49, mx50, mx51);
return p;
}
......
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/common.hpp>
#include "models.hpp" #include "models.hpp"
namespace migraphx { namespace migraphx {
...@@ -2226,9 +2225,10 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::multibroadcast multibroadcast798;
multibroadcast798.output_lens = {batch, 1000};
auto mx798 = mm->add_instruction(multibroadcast798, mx0);
-float dot799_alpha = 1;
-float dot799_beta = 1;
-migraphx::add_dot_apply_alpha_beta(*mm, {mx796, mx797, mx798}, dot799_alpha, dot799_beta);
+migraphx::op::dot dot799;
+dot799.alpha = 1;
+dot799.beta = 1;
+mm->add_instruction(dot799, mx796, mx797, mx798);
return p;
}
......
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/common.hpp>
#include "models.hpp" #include "models.hpp"
namespace migraphx { namespace migraphx {
...@@ -1229,9 +1228,10 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast442;
multibroadcast442.output_lens = {batch, 1000};
auto mx442 = mm->add_instruction(multibroadcast442, mx0);
-float dot443_alpha = 1;
-float dot443_beta = 1;
-migraphx::add_dot_apply_alpha_beta(*mm, {mx440, mx441, mx442}, dot443_alpha, dot443_beta);
+migraphx::op::dot dot443;
+dot443.alpha = 1;
+dot443.beta = 1;
+mm->add_instruction(dot443, mx440, mx441, mx442);
return p;
}
......
#include "migraphx/common.hpp"
#include "migraphx/errors.hpp"
#include "migraphx/float_equal.hpp"
#include <cmath>
#include <cstdint>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/module.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
...@@ -34,26 +28,12 @@ void eliminate_data_type::apply(module& m) const
continue;
auto op = ins->get_operator();
auto attributes = op.attributes();
-auto old_type = ins->get_shape().type();
-auto val = op.to_value();
if(attributes.contains("general_data_type"))
{
-    if(ins->name() == "quant_dot")
-    {
-        auto alpha = val.at("alpha").to<std::float_t>();
-        auto beta = val.at("beta").to<std::float_t>();
-        auto dot_res = migraphx::insert_dot_apply_alpha_beta(m, ins, inputs, alpha, beta);
-        auto convert = m.insert_instruction(
-            ins, make_op("convert", {{"target_type", old_type}}), dot_res);
-        m.replace_instruction(ins, convert);
-        return;
-    }
-    else
-    {
-        op = make_op(attributes["general_data_type"].to<std::string>(), val);
-    }
+    op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
-auto out = m.insert_instruction(ins, op, inputs);
+auto old_type = ins->get_shape().type();
+auto out = m.insert_instruction(ins, op, inputs);
auto convert =
    m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out);
m.replace_instruction(ins, convert);
......
...@@ -21,17 +21,6 @@ instruction_ref insert_common_op(module& m,
std::vector<instruction_ref> inputs);
instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs);
instruction_ref insert_dot_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
float alpha = 1.0,
float beta = 0.0);
instruction_ref add_dot_apply_alpha_beta(module& m,
const std::vector<instruction_ref>& args,
float alpha = 1.0,
float beta = 0.0);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP
...@@ -18,6 +18,15 @@ namespace op {
struct dot
{
float alpha = 1.0;
float beta = 1.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::string name() const { return "dot"; } std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -67,7 +76,7 @@ struct dot
else
result = argument{output_shape};
visit_all(result, args[0], args[1])(
-    [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1, 0); });
+    [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, alpha, beta); });
return result;
}
};
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
#define MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Decompose operators.
*/
struct remap
{
std::string name() const { return "remap"; }
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -61,7 +61,7 @@ struct parse_gemm : op_parser<parse_gemm>
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot"), l1, l2); auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2);
if(args.size() == 3)
{
......
...@@ -66,16 +66,9 @@ struct parse_matmul : op_parser<parse_matmul>
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
}
}
-instruction_ref dot_res;
-if(opd.op_name == "dot")
-{
-    dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
-}
-else
-{
-    dot_res =
-        info.add_instruction(make_op(opd.op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
-}
+auto dot_res =
+    info.add_instruction(make_op(opd.op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
{
......
#include <migraphx/remap.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/add.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
struct find_dot_add
{
auto matcher() const
{
return match::name("add")(match::any_of(
match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")),
match::args(match::used_once().bind("a"),
match::name("dot")(match::nargs(2)).bind("dot"))));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = r.instructions["a"];
auto dot = any_cast<op::dot>(dot_ins->get_operator());
dot.beta = 1;
p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins);
}
};
} // namespace
void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
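
For reference, the restored remap pass above folds a trailing add into the dot's beta term: an add(dot(A, B), C) pair becomes a single three-input dot with beta = 1, which computes A * B + C. A rough before/after sketch under that reading; the free functions are illustrative, and only the add_instruction/make_op calls mirror code shown in this diff:

#include <migraphx/module.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/make_op.hpp>

// Before remap: a two-input dot followed by an add of the bias C.
migraphx::instruction_ref unfused(migraphx::module& m,
                                  migraphx::instruction_ref a,
                                  migraphx::instruction_ref b,
                                  migraphx::instruction_ref c)
{
    auto ab = m.add_instruction(migraphx::make_op("dot"), a, b); // A * B
    return m.add_instruction(migraphx::make_op("add"), ab, c);   // A * B + C
}

// After remap (conceptually): one dot carrying the bias through beta = 1.
migraphx::instruction_ref fused(migraphx::module& m,
                                migraphx::instruction_ref a,
                                migraphx::instruction_ref b,
                                migraphx::instruction_ref c)
{
    return m.add_instruction(
        migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), a, b, c);
}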
...@@ -84,6 +84,11 @@ struct match_find_quantizable_ops
}
else if(qop->name() == "dot")
{
auto dot_op = any_cast<op::dot>(qop->get_operator());
if(!(float_equal(dot_op.alpha, 1.0f) and float_equal(dot_op.beta, 0.0f)))
return;
if(qop_args.size() == 3)
qop_args.pop_back();
dq = m.insert_instruction(
    qop, migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qop_args);
}
......
...@@ -14,6 +14,7 @@
#include <migraphx/memory_coloring.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
......
...@@ -697,6 +697,52 @@ struct find_conv_bias_relu
apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r));
}
};
struct find_gemm_add
{
auto matcher() const
{
return match::name("gpu::add")(
match::all_of[match::inputs()](match::standard_shape()),
match::either_arg(0, 1)(match::used_once().bind("c"),
match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["gemm"];
auto c_ins = r.instructions["c"];
auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
// Already fused gemm
if(not float_equal(gemm.op.beta, 0))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
return not i->get_shape().standard();
}))
return;
auto inputs = gemm_ins->inputs();
inputs.pop_back();
auto copy_ins = c_ins;
// Insert copy
if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
{
copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
}
inputs.push_back(copy_ins);
inputs.push_back(copy_ins);
gemm.op.beta = 1;
p.replace_instruction(ins, gemm, inputs);
}
};
struct find_commutative_broadcast
{
auto matcher() const
...@@ -732,7 +778,7 @@ void fuse_ops::apply(module& p) const
find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
find_add_clip{});
run_passes(p, {dead_code_elimination{}});
-    match::find_matches(p, find_triadd_layernorm{}, find_commutative_broadcast{});
+    match::find_matches(p, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
}
} // namespace gpu
......
#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/gpu/context.hpp>
...@@ -22,8 +20,6 @@ struct rocblas_gemm
{
Op op;
bool int8_x4_format = true;
float alpha = 1.0;
float beta = 0.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
...@@ -55,7 +51,7 @@ struct rocblas_gemm
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
-    gemm(ctx, output_shape, args, alpha, beta, int8_x4_format);
+    gemm(ctx, output_shape, args, op.alpha, op.beta, int8_x4_format);
return args.back();
}
......
#include <cstring>
#include <iterator>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/abs.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
...@@ -183,8 +180,8 @@ struct miopen_apply
add_extend_op("softmax");
add_extend_op("topk");
add_gemm_op("dot"); add_gemm_op<op::dot>("dot");
add_int8_gemm_op("quant_dot"); add_gemm_op<op::quant_dot>("quant_dot");
add_convolution_op();
add_deconvolution_op();
add_quant_convolution_op();
...@@ -306,12 +303,13 @@ struct miopen_apply
});
}
-void add_int8_gemm_op(const std::string& name)
+template <class Op>
+void add_gemm_op(std::string name)
{
    apply_map.emplace(name, [=](instruction_ref ins) {
-        auto&& op = any_cast<op::quant_dot>(ins->get_operator());
-        std::vector<instruction_ref> refs = ins->inputs();
+        auto&& op = any_cast<Op>(ins->get_operator());
        auto beta = op.beta;
+        std::vector<instruction_ref> refs = ins->inputs();
if(refs.size() == 2)
{
auto output = insert_allocation(ins, ins->get_shape());
...@@ -336,42 +334,7 @@ struct miopen_apply
}
return mod->replace_instruction(
-    ins,
-    rocblas_gemm<op::quant_dot>{op::quant_dot{op.alpha, beta},
-                                int8_x4_format,
-                                static_cast<float>(op.alpha),
-                                static_cast<float>(beta)},
-    refs);
+    ins, rocblas_gemm<Op>{Op{op.alpha, beta}, int8_x4_format}, refs);
});
};
void add_gemm_op(const std::string& name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
std::vector<instruction_ref> refs = ins->inputs();
if(refs.size() == 2)
{
auto output = insert_allocation(ins, ins->get_shape());
refs.push_back(output);
}
else
{
auto c_alias = instruction::get_output_alias(refs.back());
if(ins == last or refs.back()->outputs().size() > 1 or c_alias->inputs().empty())
{
auto output = insert_allocation(ins, ins->get_shape());
auto copy_out =
mod->insert_instruction(ins, make_op("hip::copy"), refs.back(), output);
refs.back() = copy_out;
refs.push_back(copy_out);
}
else
{
refs.push_back(refs.back());
}
}
return mod->replace_instruction(
ins, rocblas_gemm<op::dot>{op::dot{}, int8_x4_format, 1, 0}, refs);
});
}
......
...@@ -17,6 +17,7 @@
#include <migraphx/preallocate_param.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
......
...@@ -518,12 +518,42 @@ struct ref_gemm
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::dot"; }
-shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+shape compute_shape(const std::vector<shape>& inputs) const
{
if(inputs.size() == 3)
{
auto c_shape = inputs.at(2);
check_shapes{{c_shape}, *this}.not_broadcasted();
}
return op.compute_shape(inputs);
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
-migemm(result, args[0], args[1], 1.0f, 0.0f);
+// 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrices, and C is of the same shape as A * B
if(args.size() == 3)
{
// no need to consider the value of args[2]
if(op.beta == 0.0f)
{
result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
}
else
{
visit_all(result, args[2])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
}
migemm(result, args[0], args[1], op.alpha, op.beta);
return result;
}
// 2 input arguments
migemm(result, args[0], args[1], op.alpha, 0.0f);
return result;
}
......
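
The restored ref_gemm compute above seeds the output with C (or zeros when beta is 0) and then accumulates the product, so the three-input form computes alpha * A * B + beta * C. A plain C++ sketch of those semantics, independent of MIGraphX types; the row-major layout and function name are assumptions for illustration:

#include <cstddef>
#include <vector>

// Reference semantics of the three-input dot: D = alpha * A * B + beta * C.
// `d` must hold a copy of C on entry, mirroring how the ref target copies
// args[2] into the result buffer before calling migemm.
void gemm_ref(std::vector<float>& d,
              const std::vector<float>& a,
              const std::vector<float>& b,
              std::size_t m, std::size_t n, std::size_t k,
              float alpha, float beta)
{
    for(std::size_t i = 0; i < m; ++i)
    {
        for(std::size_t j = 0; j < n; ++j)
        {
            float acc = 0.0f;
            for(std::size_t l = 0; l < k; ++l)
                acc += a[i * k + l] * b[l * n + j];
            d[i * n + j] = alpha * acc + beta * d[i * n + j];
        }
    }
}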