Unverified Commit 48af0bcf authored by Krzysztof Drewniak's avatar Krzysztof Drewniak Committed by GitHub
Browse files

[MLIR] Capture diagnostics, handle compile failure (#2001)

Add mlir_logger, which registers a MLIR diagnostic handler that
captures any information generated by a MLIR compile and saves it to a
string.

This will be useful during tuning, where some such errors may be the
result of an inapplicable tuning configuration and should be
suppressed.
parent 7c8f6c25
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "migraphx/make_op.hpp" #include "migraphx/make_op.hpp"
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/gpu/mlir.hpp> #include <migraphx/gpu/mlir.hpp>
#include <ostream>
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
#include <mlir-c/IR.h> #include <mlir-c/IR.h>
...@@ -34,6 +35,7 @@ ...@@ -34,6 +35,7 @@
#include <mlir-c/Dialect/Rock.h> #include <mlir-c/Dialect/Rock.h>
#include <mlir-c/IntegerSet.h> #include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h> #include <mlir-c/Pass.h>
#include <mlir-c/Support.h>
#include <mutex> #include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3 #if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#warning "Incompatible version of rocMLIR library used, disabling" #warning "Incompatible version of rocMLIR library used, disabling"
...@@ -180,13 +182,85 @@ std::string mlir_print(F f, T x) ...@@ -180,13 +182,85 @@ std::string mlir_print(F f, T x)
return ss.str(); return ss.str();
} }
struct mlir_logger
{
std::stringstream ss;
mlir_context* ctx;
std::optional<MlirDiagnosticHandlerID> id;
mlir_logger() : ctx(nullptr), id(std::nullopt) {}
mlir_logger(mlir_context* context) : ctx(context)
{
id =
mlirContextAttachDiagnosticHandler(ctx->get(), mlir_diagnostic_print_cb, this, nullptr);
}
~mlir_logger()
{
if(id.has_value())
mlirContextDetachDiagnosticHandler(ctx->get(), *id);
}
mlir_logger(const mlir_logger& other) = delete;
mlir_logger& operator=(const mlir_logger& other) = delete;
mlir_logger(mlir_logger&& other) noexcept
: ss(std::move(other.ss)), ctx(other.ctx), id(other.id)
{
other.ctx = nullptr;
other.id = std::nullopt;
}
mlir_logger& operator=(mlir_logger other) noexcept
{
std::swap(ss, other.ss);
std::swap(ctx, other.ctx);
std::swap(id, other.id);
return *this;
}
std::string str() const { return ss.str(); }
void clear() { ss = std::stringstream{}; }
static MlirLogicalResult mlir_diagnostic_print_cb(MlirDiagnostic diag, void* logger);
MlirLogicalResult handle(MlirDiagnostic diag);
};
MlirLogicalResult mlir_logger::mlir_diagnostic_print_cb(MlirDiagnostic diag, void* logger)
{
return reinterpret_cast<mlir_logger*>(logger)->handle(diag);
}
MlirLogicalResult mlir_logger::handle(MlirDiagnostic diag)
{
MlirDiagnosticSeverity sev = mlirDiagnosticGetSeverity(diag);
switch(sev)
{
case MlirDiagnosticSeverity::MlirDiagnosticError: ss << "Error: "; break;
case MlirDiagnosticSeverity::MlirDiagnosticWarning: ss << "Warning: "; break;
case MlirDiagnosticSeverity::MlirDiagnosticNote: ss << "Note: "; break;
case MlirDiagnosticSeverity::MlirDiagnosticRemark: ss << "Remark: "; break;
}
mlir_print(mlirDiagnosticPrint, diag, [&](auto s) { ss << s; });
ss << std::endl;
for(intptr_t i = 0, e = mlirDiagnosticGetNumNotes(diag); i < e; ++i)
{
(void)handle(mlirDiagnosticGetNote(diag, i));
}
return mlirLogicalResultSuccess();
}
struct mlir_program struct mlir_program
{ {
mlir_program() mlir_program()
: ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(), : ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(),
/*threadingEnable=*/false)), /*threadingEnable=*/false)),
location(mlirLocationUnknownGet(ctx.get())), location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location)) mmodule(mlirModuleCreateEmpty(location)),
logger(&ctx)
{ {
mlirContextSetThreadPool(ctx.get(), get_thread_pool().get()); mlirContextSetThreadPool(ctx.get(), get_thread_pool().get());
mlirContextLoadAllAvailableDialects(ctx.get()); mlirContextLoadAllAvailableDialects(ctx.get());
...@@ -614,21 +688,49 @@ struct mlir_program ...@@ -614,21 +688,49 @@ struct mlir_program
} }
} }
void run_high_level_pipeline() MIGRAPHX_TIDY_CONST void run_high_level_pipeline()
{ {
mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())}; mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
mlirMIGraphXAddHighLevelPipeline(pm_front.get()); mlirMIGraphXAddHighLevelPipeline(pm_front.get());
mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get())); logger.clear();
if(mlirLogicalResultIsFailure(
mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()))))
{
std::string error = "Invalid MLIR created: " + logger.str();
if(enabled(MIGRAPHX_TRACE_MLIR{}))
{
std::cout << error << std::endl;
}
MIGRAPHX_THROW(error);
}
} }
void run_backend_pipeline() MIGRAPHX_TIDY_CONST void run_backend_pipeline()
{ {
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())}; mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str()); mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get())); logger.clear();
const size_t trace = value_of(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
auto mod_op = mlirModuleGetOperation(mmodule.get());
if(trace >= 2)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
if(mlirLogicalResultIsFailure(mlirPassManagerRunOnOp(pm_back.get(), mod_op)))
{
std::string error = "MLIR backend compilation failed: " + logger.str();
if(enabled(MIGRAPHX_TRACE_MLIR{}))
{
std::cout << error << std::endl;
}
MIGRAPHX_THROW(error);
}
} }
code_object_op compile(const value& solution) MIGRAPHX_TIDY_CONST code_object_op compile(const value& solution)
{ {
// 1st pipeline to call // 1st pipeline to call
run_high_level_pipeline(); run_high_level_pipeline();
...@@ -682,7 +784,7 @@ struct mlir_program ...@@ -682,7 +784,7 @@ struct mlir_program
MIGRAPHX_THROW("Failed setting tuning key: " + *str); MIGRAPHX_THROW("Failed setting tuning key: " + *str);
} }
tuning_config get_tuning_config(bool exhaustive) MIGRAPHX_TIDY_CONST tuning_config get_tuning_config(bool exhaustive)
{ {
tuning_config tc; tuning_config tc;
run_high_level_pipeline(); run_high_level_pipeline();
...@@ -809,6 +911,7 @@ struct mlir_program ...@@ -809,6 +911,7 @@ struct mlir_program
mlir_context ctx; mlir_context ctx;
MlirLocation location; MlirLocation location;
mlir_module mmodule; mlir_module mmodule;
mlir_logger logger;
problem_params pp; problem_params pp;
std::deque<std::string> strings{}; std::deque<std::string> strings{};
std::string target_arch = ""; std::string target_arch = "";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment