#include "custom_op_library.h"

// The OrtApi table is set manually from the OrtApiBase passed to
// RegisterCustomOps() below, so automatic initialization is disabled.
#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT

#include <exception>
#include <mutex>
#include <utility>
#include <vector>

#include "core/common/common.h"
#include "core/framework/ortdevice.h"
#include "core/framework/ortmemoryinfo.h"
#include "rocm_ops.h"
#include "onnxruntime_lite_custom_op.h"

// Domain the custom ops are registered under. The empty string is the default
// ONNX operator domain; switch to "test.customop" to register the ops in a
// dedicated test domain instead.
static const char* c_OpDomain = "";

// Takes ownership of a populated custom-op domain and keeps it alive in a
// function-local static container until program exit. The domain has been
// attached to session options by the caller and (per the OrtApi
// AddCustomOpDomain contract — confirm for the ORT version in use) must stay
// valid while those options / sessions are in use, so it cannot be destroyed
// when RegisterCustomOps() returns. Access is serialized by a mutex so
// concurrent registrations do not race on the vector.
static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) {
  static std::vector<Ort::CustomOpDomain> ort_custom_op_domain_container;
  static std::mutex ort_custom_op_domain_mutex;
  std::lock_guard<std::mutex> lock(ort_custom_op_domain_mutex);
  ort_custom_op_domain_container.push_back(std::move(domain));
}

// Library entry point called by ONNX Runtime to register this library's ops.
//
// Creates a custom-op domain named c_OpDomain, fills it via
// Rocm::RegisterOps(), attaches it to `options`, and hands the domain to
// AddOrtCustomOpDomainToContainer() so it outlives this call.
//
// @param options  Session options to add the domain to (not owned).
// @param api      ORT API base used to obtain the OrtApi for ORT_API_VERSION.
// @return nullptr on success; otherwise an OrtStatus* describing the failure,
//         whose ownership passes to the caller.
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
  // NOTE(review): GetApi() returns nullptr when the host runtime does not
  // support ORT_API_VERSION; that case is not handled here and any later Ort
  // call would dereference a null API table.
  Ort::Global<void>::api_ = api->GetApi(ORT_API_VERSION);
  OrtStatus* result = nullptr;

  ORT_TRY {
    Ort::CustomOpDomain domain{c_OpDomain};
    Rocm::RegisterOps(domain);

    Ort::UnownedSessionOptions session_options(options);
    session_options.Add(domain);

    AddOrtCustomOpDomainToContainer(std::move(domain));
  }
  ORT_CATCH(const Ort::Exception& e) {
    // Ort::Exception carries the original OrtErrorCode; the
    // Status(const Ort::Exception&) overload preserves it rather than
    // reporting a generic ORT_FAIL.
    ORT_HANDLE_EXCEPTION([&]() {
      Ort::Status status{e};
      result = status.release();
    });
  }
  ORT_CATCH(const std::exception& e) {
    // Any other standard exception is reported as ORT_FAIL with its message.
    ORT_HANDLE_EXCEPTION([&]() {
      Ort::Status status{e};
      result = status.release();
    });
  }

  return result;
}

// Second exported entry point with behavior identical to RegisterCustomOps();
// allows the library to be registered under an alternative function name.
OrtStatus* ORT_API_CALL RegisterCustomOpsAltName(OrtSessionOptions* options, const OrtApiBase* api) {
  return RegisterCustomOps(options, api);
}