Unverified Commit 00640366 authored by ravil-mobile's avatar ravil-mobile Committed by GitHub
Browse files

Added support for standalone dot operations with mlir (#2169)

parent 205306ac
...@@ -302,10 +302,8 @@ struct find_mlir_fused_ops ...@@ -302,10 +302,8 @@ struct find_mlir_fused_ops
} }
}; };
struct find_mlir_standalone_convolution_op struct find_mlir_standalone_op
{ {
auto matcher() const { return match::name("convolution"); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto conv_based_op = r.result; auto conv_based_op = r.result;
...@@ -327,6 +325,16 @@ struct find_mlir_standalone_convolution_op ...@@ -327,6 +325,16 @@ struct find_mlir_standalone_convolution_op
} }
}; };
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
{
auto matcher() const { return match::name("convolution"); }
};
struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
auto matcher() const { return match::name("dot"); }
};
/** /**
* @brief Declares a new MIGraphX environment variable which forces to generate * @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations. * only specific MLIR operations.
...@@ -334,7 +342,7 @@ struct find_mlir_standalone_convolution_op ...@@ -334,7 +342,7 @@ struct find_mlir_standalone_convolution_op
* The variable, if defined, forces MIGraphX to use only specific operations * The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts * with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following * a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution". If the variable is not defined MIGraphX * operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is * will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers. * intended to be primarily used by rocMLIR developers.
*/ */
...@@ -349,31 +357,33 @@ bool is_requested(std::string_view option) ...@@ -349,31 +357,33 @@ bool is_requested(std::string_view option)
return contains(options, option); return contains(options, option);
} }
bool is_fusion_enabled() bool is_enabled(std::string_view op_name, context* ctx)
{ {
if(is_self_decide()) if(is_self_decide())
{ {
return true; if(op_name == "fused")
}
return is_requested("fused");
}
bool is_standalone_convs_enabled(context* ctx)
{
if(is_self_decide())
{
if(ctx == nullptr)
{ {
return false; return true;
}
else if(op_name == "convolution")
{
if(ctx == nullptr)
{
return false;
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
}
} }
else else
{ {
const auto& device = ctx->get_current_device(); return false;
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
} }
} }
return is_requested("convolution"); return is_requested(op_name);
} }
} // namespace } // namespace
...@@ -382,21 +392,25 @@ bool is_standalone_convs_enabled(context* ctx) ...@@ -382,21 +392,25 @@ bool is_standalone_convs_enabled(context* ctx)
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
if(is_fusion_enabled()) if(is_enabled("fused", this->ctx))
{ {
match::find_matches(mpm, find_mlir_fused_ops{}); match::find_matches(mpm, find_mlir_fused_ops{});
} }
if(is_standalone_convs_enabled(this->ctx)) if(is_enabled("convolution", this->ctx))
{ {
match::find_matches(mpm, find_mlir_standalone_convolution_op{}); match::find_matches(mpm, find_mlir_standalone_convolution_op{});
} }
if(is_enabled("dot", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_dot_op{});
}
#else #else
(void)mpm; (void)mpm;
#endif #endif
} }
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment