Unverified Commit ac15425a authored by Krzysztof Drewniak's avatar Krzysztof Drewniak Committed by GitHub
Browse files

[mlir] Improve context handling to potentially solve threading bugs (#1867)

Update `mlir_program` to only create one dialect registry, and to call
registerRocMLIRPasses() (which is needed and may not be thread-safe)
exactly once. 

In addition, use a single thread pool across all contexts. This is
recommended practice upstream for libraries that perform a lot of
compile jobs, and saves on the overhead of creating and destroying a
lot of threads
parent 632d69ff
......@@ -113,7 +113,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@f5ab829b1a46eca600eb8e9df4eaa9944c845f07 -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@8d25af3b3721c159bb41cc6388e9453b1018c126 -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
......@@ -122,6 +122,9 @@ struct mlir_handle
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT
using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy);
using mlir_thread_pool = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirLlvmThreadPool, mlirLlvmThreadPoolDestroy);
using mlir_dialect_registry = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirDialectRegistry,
mlirDialectRegistryDestroy);
using mlir_module = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy);
using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy);
using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
......@@ -173,16 +176,38 @@ bool has_xdlops(const std::string& target_arch)
struct mlir_program
{
mlir_program()
: ctx(mlirContextCreate()),
: ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(),
/*threadingEnable=*/false)),
location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location))
{
MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirRegisterRocMLIRDialects(registry);
mlirContextAppendDialectRegistry(ctx.get(), registry);
mlirContextSetThreadPool(ctx.get(), get_thread_pool().get());
mlirContextLoadAllAvailableDialects(ctx.get());
mlirDialectRegistryDestroy(registry);
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
}
static mlir_dialect_registry& get_dialect_registry()
{
static std::once_flag init_guard;
static mlir_dialect_registry the_registry;
// The MLIR registration functions (for dialects and passes) are not
// necessarily thread-safe and need to be executed exactly once
// (especially since they eventually call non-thread-safe LLVM
// initilizations).
std::call_once(init_guard, [&]() {
the_registry = mlirDialectRegistryCreate();
mlirRegisterRocMLIRDialects(the_registry.get());
mlirRegisterRocMLIRPasses();
});
return the_registry;
}
static mlir_thread_pool& get_thread_pool()
{
// To save on overhead, we create one LLVM thread pool and reuse it
// across all MLIR contexts as recommended by MLIR upstream.
// Note that this is thread-safe as of C++11.
static mlir_thread_pool the_pool = mlirLlvmThreadPoolCreate();
return the_pool;
}
MlirType make_type(shape::type_t t) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment