Unverified Commit 7a7040aa authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Dynamically plug-in backend target libs (#1608)

Fixes #1595
parent 9ef6801e
......@@ -46,6 +46,8 @@ else()
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
endif()
set(CMAKE_BUILD_RPATH "${CMAKE_BINARY_DIR}/lib")
project(migraphx)
find_package(ROCM REQUIRED)
......
......@@ -29,7 +29,7 @@ set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx)
rocm_install_targets(
TARGETS migraphx_c
......
......@@ -32,7 +32,6 @@
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
......
......@@ -42,7 +42,7 @@ add_custom_command(
)
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf)
target_link_libraries(driver migraphx migraphx_onnx migraphx_tf)
rocm_install_targets(
TARGETS driver
......
......@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify.hpp"
#include "argument_parser.hpp"
#include "command.hpp"
......
......@@ -24,7 +24,7 @@
#include "verify.hpp"
#include "perf.hpp"
#include <migraphx/ref/target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp>
......@@ -37,7 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
std::vector<argument> run_ref(program p, const parameter_map& inputs)
{
p.compile(ref::target{});
p.compile(migraphx::make_target("ref"));
auto out = p.eval(inputs);
std::cout << p << std::endl;
return out;
......
......@@ -21,25 +21,44 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/manage_ptr.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <utility>
#include <dlfcn.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void check_load_error(bool flush = false)
{
char* error_msg = dlerror();
if(not flush and error_msg != nullptr)
MIGRAPHX_THROW("Dynamic loading or symbol lookup failed with " + std::string(error_msg));
}
struct dynamic_loader_impl
{
dynamic_loader_impl() = default;
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
: handle(dlopen(p.string().c_str(), RTLD_LAZY), &dlclose), temp(std::move(t))
: handle(dlopen(p.string().c_str(), RTLD_LAZY),
manage_deleter<decltype(&dlclose), &dlclose>{}),
temp(std::move(t))
{
check_load_error();
}
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
{
auto t = std::make_shared<tmp_dir>("dloader");
......@@ -68,10 +87,11 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{
dlerror();
// flush any previous error messages
check_load_error(true);
void* symbol = dlsym(impl->handle.get(), name.c_str());
if(symbol == nullptr)
MIGRAPHX_THROW("Symbol not found: " + name);
check_load_error();
return {impl, symbol};
}
......
......@@ -33,15 +33,36 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// unregister all ops for specified target, useful when unloading dynamically plugged-in target lib
void unregister_op(const std::string& op_name);
namespace detail {
struct op_handler
{
operation op;
std::string name;
op_handler(const operation& op_r) : op(op_r), name(op.name()){};
~op_handler() { unregister_op(name); }
};
} // namespace detail
void register_op_init();
void register_op(const operation& op);
operation load_op(const std::string& name);
bool has_op(const std::string& name);
std::vector<std::string> get_operators();
template <class T>
void register_op()
{
register_op(T{});
register_op_init(); // instantiate static op_map;
static auto op_h = detail::op_handler(T{});
register_op(op_h.op);
}
struct register_op_action
......
......@@ -33,14 +33,28 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void register_target_init();
void register_target(const target& t);
void unregister_target(const std::string& name);
target make_target(const std::string& name);
std::vector<std::string> get_targets();
namespace detail {
struct target_handler
{
target t;
std::string target_name;
target_handler(const target& t_r) : t(t_r), target_name(t.name()) {}
~target_handler() { unregister_target(target_name); }
};
} // namespace detail
template <class T>
void register_target()
{
register_target(T{});
register_target_init();
static auto t_h = detail::target_handler(T{});
register_target(t_h.t);
}
struct register_target_action
......
......@@ -33,7 +33,17 @@ std::unordered_map<std::string, operation>& op_map()
static std::unordered_map<std::string, operation> m; // NOLINT
return m;
}
void register_op_init() { (void)op_map(); }
void register_op(const operation& op) { op_map()[op.name()] = op; }
void unregister_op(const std::string& op_name)
{
assert(op_map().count(op_name));
op_map().erase(op_name);
}
operation load_op(const std::string& name)
{
return at(op_map(), name, "Operator not found: " + name);
......
......@@ -21,26 +21,48 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <string>
#include <unordered_map>
#include <migraphx/register_target.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dynamic_loader.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void store_target_lib(const dynamic_loader& lib)
{
static std::vector<dynamic_loader> target_loader;
target_loader.emplace_back(lib);
}
std::unordered_map<std::string, target>& target_map()
{
static std::unordered_map<std::string, target> m; // NOLINT
return m;
}
void register_target_init() { (void)target_map(); }
void unregister_target(const std::string& name)
{
assert(target_map().count(name));
target_map().erase(name);
}
void register_target(const target& t) { target_map()[t.name()] = t; }
target make_target(const std::string& name)
{
if(not contains(target_map(), name))
{
std::string target_name = "libmigraphx_" + name + ".so";
store_target_lib(dynamic_loader(target_name));
}
const auto it = target_map().find(name);
if(it == target_map().end())
{
MIGRAPHX_THROW("Requested target '" + name + "' is not enabled or not supported");
MIGRAPHX_THROW("Requested target '" + name + "' is not loaded or not supported");
}
return it->second;
}
......
......@@ -40,14 +40,11 @@ struct target
std::string name() const;
std::vector<pass> get_passes(migraphx::context& gctx, const compile_options&) const;
migraphx::context get_context() const { return context{}; }
argument copy_to(const argument& arg) const { return arg; }
argument copy_from(const argument& arg) const { return arg; }
argument allocate(const shape& s) const;
};
MIGRAPHX_REGISTER_TARGET(target);
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -23,7 +23,6 @@
*/
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/check_context.hpp>
#include <migraphx/adjust_allocation.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_allocation.hpp>
......
......@@ -43,14 +43,11 @@ struct target
std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const;
migraphx::context get_context() const { return context{}; }
supported_segments find_supported(const_module_ref mod, support_metric m) const;
argument copy_to(const argument& arg) const { return arg; }
argument copy_from(const argument& arg) const { return arg; }
argument allocate(const shape& s) const;
};
MIGRAPHX_REGISTER_TARGET(target);
} // namespace fpga
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -27,7 +27,6 @@
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp>
......
......@@ -56,7 +56,6 @@ struct oper
return name.substr(pos_ns + 2);
}
}
return "unknown_operator_name";
}
};
......
......@@ -37,7 +37,6 @@ struct target
std::string name() const;
std::vector<pass> get_passes(migraphx::context& gctx, const compile_options& options) const;
migraphx::context get_context() const;
argument copy_to(const argument& arg) const;
argument copy_from(const argument& arg) const;
argument allocate(const shape& s) const;
......
......@@ -74,7 +74,6 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
struct id_pass
{
std::string name() const { return "id"; }
......
......@@ -46,8 +46,6 @@ struct target
argument allocate(const shape& s) const;
};
MIGRAPHX_REGISTER_TARGET(target);
} // namespace ref
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -106,19 +106,11 @@ function(add_test_executable TEST_NAME)
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU")
set_target_properties(${TEST_NAME} PROPERTIES COMPILE_FLAGS -pthread LINK_FLAGS -pthread)
endif()
separate_arguments(MIOPEN_TEST_FLAGS_ARGS UNIX_COMMAND ${MIOPEN_TEST_FLAGS})
if(MIOPEN_TEST_ALL)
set(TEST_COMMAND ${TEST_NAME} ${MIOPEN_TEST_FLOAT_ARG} --all ${MIOPEN_TEST_FLAGS_ARGS})
else()
set(TEST_COMMAND ${TEST_NAME} ${MIOPEN_TEST_FLOAT_ARG} ${MIOPEN_TEST_FLAGS_ARGS})
endif()
set(TEST_COMMAND ${TEST_NAME})
add_test_command(${TEST_NAME} ${TEST_COMMAND})
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx migraphx_ref migraphx_onnx)
target_link_libraries(${TEST_NAME} migraphx migraphx_onnx)
target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable)
......@@ -171,7 +163,7 @@ foreach(ONNX_TEST ${ONNX_TESTS})
set(TEST_NAME test_${BASE_NAME})
add_executable(${TEST_NAME} ${ONNX_TEST})
rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref)
target_link_libraries(${TEST_NAME} migraphx_onnx)
target_include_directories(${TEST_NAME} PUBLIC include)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_dependencies(tests ${TEST_NAME})
......@@ -182,7 +174,7 @@ endforeach()
set(TEST_TF_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tf)
add_executable(test_tf tf/tf_test.cpp)
rocm_clang_tidy_check(test_tf)
target_link_libraries(test_tf migraphx_tf migraphx_ref)
target_link_libraries(test_tf migraphx_tf)
target_include_directories(test_tf PUBLIC include)
add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${TEST_TF_DIR})
add_dependencies(tests test_tf)
......@@ -216,10 +208,7 @@ function(test_headers PREFIX)
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
get_filename_component(BASE_NAME ${HEADER} NAME_WE)
test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(header_${TEST_NAME} migraphx_gpu)
endif()
target_link_libraries(header_${TEST_NAME} migraphx_all_targets)
endforeach()
endfunction()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment