Commit e7f93b38 authored by jerryyin's avatar jerryyin
Browse files

[mlir] Adding enabled quant convolution fusion

parent 71c8181c
......@@ -174,7 +174,13 @@ std::string cpp_generator::generate_point_op(const operation& op,
else if(with_char(::isdigit)(key[0]))
{
auto i = std::stoul(key);
return args.at(i);
// For an optional argument where i >= args.size(), treat
// the optional argument as a straight zero. This will
// cancel out the optional bias, if it exists.
if(i < args.size())
return args.at(i);
else
return "0";
}
else if(v.contains(key))
{
......
......@@ -37,6 +37,12 @@ namespace op {
struct dequantizelinear
{
value attributes() const {
return {{"pointwise", true}, {"point_op",
"${1} * (${function:convert}<float>(${0}) - ${function:convert}<float>(${2}))"}};
}
std::string name() const { return "dequantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -38,6 +38,10 @@ namespace op {
struct quantizelinear
{
std::string name() const { return "quantizelinear"; }
value attributes() const { return {{"pointwise", true}, {"point_op",
"${function:max}(${function:min}(${function:round}(${function:convert}<float>(${0}) / ${1}) + ${function:convert}<float>(${2}), 127.0), -128.0)"}}; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_dims().has(2, 3);
......
......@@ -38,6 +38,27 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);
// Reports whether the MLIR kernel generator should be used.
// Returns true only when MIGraphX was built with MLIR support
// (MIGRAPHX_MLIR defined) AND the MIGRAPHX_ENABLE_MLIR env var is set;
// otherwise returns false, warning once per call when support was
// compiled in but left disabled.
bool mlir_enabled()
{
#ifdef MIGRAPHX_MLIR
    // Guard clause: env var set -> MLIR is usable.
    if(enabled(MIGRAPHX_ENABLE_MLIR{}))
        return true;
    // Built with MLIR but not opted in: tell the user how to enable it.
    std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
                 "var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
              << std::endl;
    return false;
#else
    // No MLIR support compiled in at all.
    return false;
#endif
}
#ifdef MIGRAPHX_MLIR
struct mlir_op
......@@ -58,8 +79,11 @@ struct mlir_op
MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
auto n = inputs.size();
return op.compute_shape({inputs[n - 2], inputs[n - 1]});
auto n = inputs.size();
auto* pm = mods.front();
auto type = pm->get_output_shapes().front().type();
auto shape = op.compute_shape({inputs[n - 2], inputs[n - 1]});
return shape.with_type(type);
}
};
MIGRAPHX_REGISTER_OP(mlir_op);
......@@ -68,7 +92,7 @@ namespace {
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{
if(ins->name() != "convolution")
if(ins->name() != "convolution" && ins->name() != "quant_convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
......@@ -98,14 +122,25 @@ struct find_mlir_op
auto names = pm->get_parameter_names();
// Whitelist pointwise operators
if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
return not contains(
{"@literal", "@param", "@return", "convolution", "dot", "add", "relu"},
i.name());
return not contains({"@literal",
"@param",
"@return",
"convolution",
"quant_convolution",
"dot",
"add",
"relu",
"dequantizelinear",
"quantizelinear"},
i.name());
}))
return;
// Only fuse with fp32/fp16
// Only fuse with fp32/fp16/int8/int32
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return not contains({shape::type_t::float_type, shape::type_t::half_type},
return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::int8_type,
shape::type_t::int32_type},
i->get_shape().type());
}))
return;
......@@ -148,17 +183,7 @@ struct find_mlir_op
void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
if(mlir_enabled)
{
match::find_matches(mpm, find_mlir_op{});
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
}
match::find_matches(mpm, find_mlir_op{});
#else
(void)mpm;
#endif
......
......@@ -34,6 +34,8 @@ struct module_pass_manager;
namespace gpu {
bool mlir_enabled();
struct fuse_mlir
{
context* ctx = nullptr;
......
......@@ -107,7 +107,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
normalize_ops{},
dead_code_elimination{},
simplify_qdq{},
rewrite_quantization{},
enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
simplify_reshapes{},
......@@ -133,7 +133,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{},
fuse_mlir{&ctx},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment