Unverified Commit 00640366 authored by ravil-mobile's avatar ravil-mobile Committed by GitHub
Browse files

Added support for standalone dot operations with mlir (#2169)

parent 205306ac
......@@ -302,10 +302,8 @@ struct find_mlir_fused_ops
}
};
struct find_mlir_standalone_convolution_op
struct find_mlir_standalone_op
{
auto matcher() const { return match::name("convolution"); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
......@@ -327,6 +325,16 @@ struct find_mlir_standalone_convolution_op
}
};
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
{
auto matcher() const { return match::name("convolution"); }
};
struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
auto matcher() const { return match::name("dot"); }
};
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
......@@ -334,7 +342,7 @@ struct find_mlir_standalone_convolution_op
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution". If the variable is not defined MIGraphX
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
......@@ -349,31 +357,33 @@ bool is_requested(std::string_view option)
return contains(options, option);
}
bool is_fusion_enabled()
bool is_enabled(std::string_view op_name, context* ctx)
{
if(is_self_decide())
{
return true;
}
return is_requested("fused");
}
bool is_standalone_convs_enabled(context* ctx)
{
if(is_self_decide())
{
if(ctx == nullptr)
if(op_name == "fused")
{
return false;
return true;
}
else if(op_name == "convolution")
{
if(ctx == nullptr)
{
return false;
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
}
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
return false;
}
}
return is_requested("convolution");
return is_requested(op_name);
}
} // namespace
......@@ -382,21 +392,25 @@ bool is_standalone_convs_enabled(context* ctx)
void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
if(is_fusion_enabled())
if(is_enabled("fused", this->ctx))
{
match::find_matches(mpm, find_mlir_fused_ops{});
}
if(is_standalone_convs_enabled(this->ctx))
if(is_enabled("convolution", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_convolution_op{});
}
if(is_enabled("dot", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_dot_op{});
}
#else
(void)mpm;
#endif
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment