Unverified Commit 7a7040aa authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Dynamically plug-in backend target libs (#1608)

Fixes #1595
parent 9ef6801e
...@@ -46,6 +46,8 @@ else() ...@@ -46,6 +46,8 @@ else()
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
endif() endif()
set(CMAKE_BUILD_RPATH "${CMAKE_BINARY_DIR}/lib")
project(migraphx) project(migraphx)
find_package(ROCM REQUIRED) find_package(ROCM REQUIRED)
......
...@@ -29,7 +29,7 @@ set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) ...@@ -29,7 +29,7 @@ set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 3.0) rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c) rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets) target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx)
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_c TARGETS migraphx_c
......
...@@ -32,7 +32,6 @@ ...@@ -32,7 +32,6 @@
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
......
...@@ -42,7 +42,7 @@ add_custom_command( ...@@ -42,7 +42,7 @@ add_custom_command(
) )
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver) set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
rocm_clang_tidy_check(driver) rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf) target_link_libraries(driver migraphx migraphx_onnx migraphx_tf)
rocm_install_targets( rocm_install_targets(
TARGETS driver TARGETS driver
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "verify.hpp" #include "verify.hpp"
#include "argument_parser.hpp" #include "argument_parser.hpp"
#include "command.hpp" #include "command.hpp"
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "verify.hpp" #include "verify.hpp"
#include "perf.hpp" #include "perf.hpp"
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp> #include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -37,7 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -37,7 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
std::vector<argument> run_ref(program p, const parameter_map& inputs) std::vector<argument> run_ref(program p, const parameter_map& inputs)
{ {
p.compile(ref::target{}); p.compile(migraphx::make_target("ref"));
auto out = p.eval(inputs); auto out = p.eval(inputs);
std::cout << p << std::endl; std::cout << p << std::endl;
return out; return out;
......
...@@ -21,25 +21,44 @@ ...@@ -21,25 +21,44 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/manage_ptr.hpp>
#include <migraphx/dynamic_loader.hpp> #include <migraphx/dynamic_loader.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp> #include <migraphx/tmp_dir.hpp>
#include <utility> #include <utility>
#include <dlfcn.h> #include <dlfcn.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void check_load_error(bool flush = false)
{
char* error_msg = dlerror();
if(not flush and error_msg != nullptr)
MIGRAPHX_THROW("Dynamic loading or symbol lookup failed with " + std::string(error_msg));
}
struct dynamic_loader_impl struct dynamic_loader_impl
{ {
dynamic_loader_impl() = default; dynamic_loader_impl() = default;
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr) dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
: handle(dlopen(p.string().c_str(), RTLD_LAZY), &dlclose), temp(std::move(t)) : handle(dlopen(p.string().c_str(), RTLD_LAZY),
manage_deleter<decltype(&dlclose), &dlclose>{}),
temp(std::move(t))
{ {
check_load_error();
} }
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size) static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
{ {
auto t = std::make_shared<tmp_dir>("dloader"); auto t = std::make_shared<tmp_dir>("dloader");
...@@ -68,10 +87,11 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer) ...@@ -68,10 +87,11 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{ {
dlerror(); // flush any previous error messages
check_load_error(true);
void* symbol = dlsym(impl->handle.get(), name.c_str()); void* symbol = dlsym(impl->handle.get(), name.c_str());
if(symbol == nullptr) if(symbol == nullptr)
MIGRAPHX_THROW("Symbol not found: " + name); check_load_error();
return {impl, symbol}; return {impl, symbol};
} }
......
...@@ -33,15 +33,36 @@ ...@@ -33,15 +33,36 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// unregister all ops for specified target, useful when unloading dynamically plugged-in target lib
void unregister_op(const std::string& op_name);
namespace detail {
struct op_handler
{
operation op;
std::string name;
op_handler(const operation& op_r) : op(op_r), name(op.name()){};
~op_handler() { unregister_op(name); }
};
} // namespace detail
void register_op_init();
void register_op(const operation& op); void register_op(const operation& op);
operation load_op(const std::string& name); operation load_op(const std::string& name);
bool has_op(const std::string& name); bool has_op(const std::string& name);
std::vector<std::string> get_operators(); std::vector<std::string> get_operators();
template <class T> template <class T>
void register_op() void register_op()
{ {
register_op(T{}); register_op_init(); // instantiate static op_map;
static auto op_h = detail::op_handler(T{});
register_op(op_h.op);
} }
struct register_op_action struct register_op_action
......
...@@ -33,14 +33,28 @@ ...@@ -33,14 +33,28 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void register_target_init();
void register_target(const target& t); void register_target(const target& t);
void unregister_target(const std::string& name);
target make_target(const std::string& name); target make_target(const std::string& name);
std::vector<std::string> get_targets(); std::vector<std::string> get_targets();
namespace detail {
struct target_handler
{
target t;
std::string target_name;
target_handler(const target& t_r) : t(t_r), target_name(t.name()) {}
~target_handler() { unregister_target(target_name); }
};
} // namespace detail
template <class T> template <class T>
void register_target() void register_target()
{ {
register_target(T{}); register_target_init();
static auto t_h = detail::target_handler(T{});
register_target(t_h.t);
} }
struct register_target_action struct register_target_action
......
...@@ -33,7 +33,17 @@ std::unordered_map<std::string, operation>& op_map() ...@@ -33,7 +33,17 @@ std::unordered_map<std::string, operation>& op_map()
static std::unordered_map<std::string, operation> m; // NOLINT static std::unordered_map<std::string, operation> m; // NOLINT
return m; return m;
} }
void register_op_init() { (void)op_map(); }
void register_op(const operation& op) { op_map()[op.name()] = op; } void register_op(const operation& op) { op_map()[op.name()] = op; }
void unregister_op(const std::string& op_name)
{
assert(op_map().count(op_name));
op_map().erase(op_name);
}
operation load_op(const std::string& name) operation load_op(const std::string& name)
{ {
return at(op_map(), name, "Operator not found: " + name); return at(op_map(), name, "Operator not found: " + name);
......
...@@ -21,26 +21,48 @@ ...@@ -21,26 +21,48 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <string>
#include <unordered_map> #include <unordered_map>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dynamic_loader.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void store_target_lib(const dynamic_loader& lib)
{
static std::vector<dynamic_loader> target_loader;
target_loader.emplace_back(lib);
}
std::unordered_map<std::string, target>& target_map() std::unordered_map<std::string, target>& target_map()
{ {
static std::unordered_map<std::string, target> m; // NOLINT static std::unordered_map<std::string, target> m; // NOLINT
return m; return m;
} }
void register_target_init() { (void)target_map(); }
void unregister_target(const std::string& name)
{
assert(target_map().count(name));
target_map().erase(name);
}
void register_target(const target& t) { target_map()[t.name()] = t; } void register_target(const target& t) { target_map()[t.name()] = t; }
target make_target(const std::string& name) target make_target(const std::string& name)
{ {
if(not contains(target_map(), name))
{
std::string target_name = "libmigraphx_" + name + ".so";
store_target_lib(dynamic_loader(target_name));
}
const auto it = target_map().find(name); const auto it = target_map().find(name);
if(it == target_map().end()) if(it == target_map().end())
{ {
MIGRAPHX_THROW("Requested target '" + name + "' is not enabled or not supported"); MIGRAPHX_THROW("Requested target '" + name + "' is not loaded or not supported");
} }
return it->second; return it->second;
} }
......
...@@ -40,14 +40,11 @@ struct target ...@@ -40,14 +40,11 @@ struct target
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraphx::context& gctx, const compile_options&) const; std::vector<pass> get_passes(migraphx::context& gctx, const compile_options&) const;
migraphx::context get_context() const { return context{}; } migraphx::context get_context() const { return context{}; }
argument copy_to(const argument& arg) const { return arg; } argument copy_to(const argument& arg) const { return arg; }
argument copy_from(const argument& arg) const { return arg; } argument copy_from(const argument& arg) const { return arg; }
argument allocate(const shape& s) const; argument allocate(const shape& s) const;
}; };
MIGRAPHX_REGISTER_TARGET(target);
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
*/ */
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/check_context.hpp>
#include <migraphx/adjust_allocation.hpp> #include <migraphx/adjust_allocation.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_allocation.hpp> #include <migraphx/eliminate_allocation.hpp>
......
...@@ -43,14 +43,11 @@ struct target ...@@ -43,14 +43,11 @@ struct target
std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const; std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const;
migraphx::context get_context() const { return context{}; } migraphx::context get_context() const { return context{}; }
supported_segments find_supported(const_module_ref mod, support_metric m) const; supported_segments find_supported(const_module_ref mod, support_metric m) const;
argument copy_to(const argument& arg) const { return arg; } argument copy_to(const argument& arg) const { return arg; }
argument copy_from(const argument& arg) const { return arg; } argument copy_from(const argument& arg) const { return arg; }
argument allocate(const shape& s) const; argument allocate(const shape& s) const;
}; };
MIGRAPHX_REGISTER_TARGET(target);
} // namespace fpga } // namespace fpga
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
......
...@@ -56,7 +56,6 @@ struct oper ...@@ -56,7 +56,6 @@ struct oper
return name.substr(pos_ns + 2); return name.substr(pos_ns + 2);
} }
} }
return "unknown_operator_name"; return "unknown_operator_name";
} }
}; };
......
...@@ -37,7 +37,6 @@ struct target ...@@ -37,7 +37,6 @@ struct target
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraphx::context& gctx, const compile_options& options) const; std::vector<pass> get_passes(migraphx::context& gctx, const compile_options& options) const;
migraphx::context get_context() const; migraphx::context get_context() const;
argument copy_to(const argument& arg) const; argument copy_to(const argument& arg) const;
argument copy_from(const argument& arg) const; argument copy_from(const argument& arg) const;
argument allocate(const shape& s) const; argument allocate(const shape& s) const;
......
...@@ -74,7 +74,6 @@ namespace gpu { ...@@ -74,7 +74,6 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
struct id_pass struct id_pass
{ {
std::string name() const { return "id"; } std::string name() const { return "id"; }
......
...@@ -46,8 +46,6 @@ struct target ...@@ -46,8 +46,6 @@ struct target
argument allocate(const shape& s) const; argument allocate(const shape& s) const;
}; };
MIGRAPHX_REGISTER_TARGET(target);
} // namespace ref } // namespace ref
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -106,19 +106,11 @@ function(add_test_executable TEST_NAME) ...@@ -106,19 +106,11 @@ function(add_test_executable TEST_NAME)
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU") if(CMAKE_CXX_COMPILER_ID MATCHES "GNU")
set_target_properties(${TEST_NAME} PROPERTIES COMPILE_FLAGS -pthread LINK_FLAGS -pthread) set_target_properties(${TEST_NAME} PROPERTIES COMPILE_FLAGS -pthread LINK_FLAGS -pthread)
endif() endif()
set(TEST_COMMAND ${TEST_NAME})
separate_arguments(MIOPEN_TEST_FLAGS_ARGS UNIX_COMMAND ${MIOPEN_TEST_FLAGS})
if(MIOPEN_TEST_ALL)
set(TEST_COMMAND ${TEST_NAME} ${MIOPEN_TEST_FLOAT_ARG} --all ${MIOPEN_TEST_FLAGS_ARGS})
else()
set(TEST_COMMAND ${TEST_NAME} ${MIOPEN_TEST_FLOAT_ARG} ${MIOPEN_TEST_FLAGS_ARGS})
endif()
add_test_command(${TEST_NAME} ${TEST_COMMAND}) add_test_command(${TEST_NAME} ${TEST_COMMAND})
add_dependencies(tests ${TEST_NAME}) add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME}) add_dependencies(check ${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx migraphx_ref migraphx_onnx) target_link_libraries(${TEST_NAME} migraphx migraphx_onnx)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable) endfunction(add_test_executable)
...@@ -171,7 +163,7 @@ foreach(ONNX_TEST ${ONNX_TESTS}) ...@@ -171,7 +163,7 @@ foreach(ONNX_TEST ${ONNX_TESTS})
set(TEST_NAME test_${BASE_NAME}) set(TEST_NAME test_${BASE_NAME})
add_executable(${TEST_NAME} ${ONNX_TEST}) add_executable(${TEST_NAME} ${ONNX_TEST})
rocm_clang_tidy_check(${TEST_NAME}) rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref) target_link_libraries(${TEST_NAME} migraphx_onnx)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_dependencies(tests ${TEST_NAME}) add_dependencies(tests ${TEST_NAME})
...@@ -182,7 +174,7 @@ endforeach() ...@@ -182,7 +174,7 @@ endforeach()
set(TEST_TF_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tf) set(TEST_TF_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tf)
add_executable(test_tf tf/tf_test.cpp) add_executable(test_tf tf/tf_test.cpp)
rocm_clang_tidy_check(test_tf) rocm_clang_tidy_check(test_tf)
target_link_libraries(test_tf migraphx_tf migraphx_ref) target_link_libraries(test_tf migraphx_tf)
target_include_directories(test_tf PUBLIC include) target_include_directories(test_tf PUBLIC include)
add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${TEST_TF_DIR}) add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${TEST_TF_DIR})
add_dependencies(tests test_tf) add_dependencies(tests test_tf)
...@@ -216,10 +208,7 @@ function(test_headers PREFIX) ...@@ -216,10 +208,7 @@ function(test_headers PREFIX)
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME) string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
get_filename_component(BASE_NAME ${HEADER} NAME_WE) get_filename_component(BASE_NAME ${HEADER} NAME_WE)
test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp) test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp)
target_link_libraries(header_${TEST_NAME} migraphx_all_targets)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(header_${TEST_NAME} migraphx_gpu)
endif()
endforeach() endforeach()
endfunction() endfunction()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment