/*! * Copyright (c) 2016 by Contributors * \file c_runtime_api.cc * \brief Runtime API implementation */ #include #include #include #include #include #include #include #include #include #include #include #include #include "runtime_base.h" namespace dgl { namespace runtime { /*! * \brief The name of Device API factory. * \param type The device type. */ inline std::string DeviceName(int type) { switch (type) { case kDLCPU: return "cpu"; case kDLGPU: return "gpu"; case kDLOpenCL: return "opencl"; case kDLSDAccel: return "sdaccel"; case kDLAOCL: return "aocl"; case kDLVulkan: return "vulkan"; case kDLMetal: return "metal"; case kDLVPI: return "vpi"; case kDLROCM: return "rocm"; case kOpenGL: return "opengl"; case kExtDev: return "ext_dev"; default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; } } class DeviceAPIManager { public: static const int kMaxDeviceAPI = 32; // Get API static DeviceAPI* Get(const DGLContext& ctx) { return Get(ctx.device_type); } static DeviceAPI* Get(int dev_type, bool allow_missing = false) { return Global()->GetAPI(dev_type, allow_missing); } private: std::array api_; DeviceAPI* rpc_api_{nullptr}; std::mutex mutex_; // constructor DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } // Global static variable. static DeviceAPIManager* Global() { static DeviceAPIManager inst; return &inst; } // Get or initialize API. DeviceAPI* GetAPI(int type, bool allow_missing) { if (type < kRPCSessMask) { if (api_[type] != nullptr) return api_[type]; std::lock_guard lock(mutex_); if (api_[type] != nullptr) return api_[type]; api_[type] = GetAPI(DeviceName(type), allow_missing); return api_[type]; } else { if (rpc_api_ != nullptr) return rpc_api_; std::lock_guard lock(mutex_); if (rpc_api_ != nullptr) return rpc_api_; rpc_api_ = GetAPI("rpc", allow_missing); return rpc_api_; } } DeviceAPI* GetAPI(const std::string name, bool allow_missing) { std::string factory = "device_api." + name; auto* f = Registry::Get(factory); if (f == nullptr) { CHECK(allow_missing) << "Device API " << name << " is not enabled. Please install the cuda version of dgl."; return nullptr; } void* ptr = (*f)(); return static_cast(ptr); } }; DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) { return DeviceAPIManager::Get( static_cast(ctx.device_type), allow_missing); } DeviceAPI* DeviceAPI::Get(DLDeviceType dev_type, bool allow_missing) { return DeviceAPIManager::Get(static_cast(dev_type), allow_missing); } void* DeviceAPI::AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) { return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint); } void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) { FreeDataSpace(ctx, ptr); } DGLStreamHandle DeviceAPI::CreateStream(DGLContext ctx) { LOG(FATAL) << "Device does not support stream api."; return 0; } void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) { LOG(FATAL) << "Device does not support stream api."; } void DeviceAPI::SyncStreamFromTo(DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) { LOG(FATAL) << "Device does not support stream api."; } void DeviceAPI::PinData(void* ptr, size_t nbytes) { LOG(FATAL) << "Device does not support hipHostRegister api."; } void DeviceAPI::UnpinData(void* ptr) { LOG(FATAL) << "Device does not support hipHostUnregister api."; } } // namespace runtime } // namespace dgl using namespace dgl::runtime; struct DGLRuntimeEntry { std::string ret_str; std::string last_error; DGLByteArray ret_bytes; }; typedef dmlc::ThreadLocalStore DGLAPIRuntimeStore; const char *DGLGetLastError() { return DGLAPIRuntimeStore::Get()->last_error.c_str(); } void DGLAPISetLastError(const char* msg) { #ifndef _LIBCPP_SGX_CONFIG DGLAPIRuntimeStore::Get()->last_error = msg; #else sgx::OCallPackedFunc("__sgx_set_last_error__", msg); #endif } int DGLModLoadFromFile(const char* file_name, const char* format, DGLModuleHandle* out) { API_BEGIN(); Module m = Module::LoadFromFile(file_name, format); *out = new Module(m); API_END(); } int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep) { API_BEGIN(); static_cast(mod)->Import( *static_cast(dep)); API_END(); } int DGLModGetFunction(DGLModuleHandle mod, const char* func_name, int query_imports, DGLFunctionHandle *func) { API_BEGIN(); PackedFunc pf = static_cast(mod)->GetFunction( func_name, query_imports != 0); if (pf != nullptr) { *func = new PackedFunc(pf); } else { *func = nullptr; } API_END(); } int DGLModFree(DGLModuleHandle mod) { API_BEGIN(); delete static_cast(mod); API_END(); } int DGLBackendGetFuncFromEnv(void* mod_node, const char* func_name, DGLFunctionHandle *func) { API_BEGIN(); *func = (DGLFunctionHandle)( static_cast(mod_node)->GetFuncFromEnv(func_name)); API_END(); } void* DGLBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, int dtype_bits_hint) { DGLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; DGLType type_hint; type_hint.code = static_cast(dtype_code_hint); type_hint.bits = static_cast(dtype_bits_hint); type_hint.lanes = 1; return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast(size), type_hint); } int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr) { DGLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr); return 0; } int DGLBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { if (*handle == nullptr) { *handle = reinterpret_cast(1); return (*f)(cdata); } return 0; } int DGLFuncFree(DGLFunctionHandle func) { API_BEGIN(); delete static_cast(func); API_END(); } int DGLFuncCall(DGLFunctionHandle func, DGLValue* args, int* arg_type_codes, int num_args, DGLValue* ret_val, int* ret_type_code) { API_BEGIN(); DGLRetValue rv; (*static_cast(func)).CallPacked( DGLArgs(args, arg_type_codes, num_args), &rv); // handle return string. if (rv.type_code() == kStr || rv.type_code() == kDGLType || rv.type_code() == kBytes) { DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get(); if (rv.type_code() != kDGLType) { e->ret_str = *rv.ptr(); } else { e->ret_str = rv.operator std::string(); } if (rv.type_code() == kBytes) { e->ret_bytes.data = e->ret_str.c_str(); e->ret_bytes.size = e->ret_str.length(); *ret_type_code = kBytes; ret_val->v_handle = &(e->ret_bytes); } else { *ret_type_code = kStr; ret_val->v_str = e->ret_str.c_str(); } } else { rv.MoveToCHost(ret_val, ret_type_code); } API_END(); } int DGLCFuncSetReturn(DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret) { API_BEGIN(); CHECK_EQ(num_ret, 1); DGLRetValue* rv = static_cast(ret); *rv = DGLArgValue(value[0], type_code[0]); API_END(); } int DGLFuncCreateFromCFunc(DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin, DGLFunctionHandle *out) { API_BEGIN(); if (fin == nullptr) { *out = new PackedFunc( [func, resource_handle](DGLArgs args, DGLRetValue* rv) { int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*) args.num_args, rv, resource_handle); if (ret != 0) { std::string err = "DGLCall CFunc Error:\n"; err += DGLGetLastError(); throw dmlc::Error(err); } }); } else { // wrap it in a shared_ptr, with fin as deleter. // so fin will be called when the lambda went out of scope. std::shared_ptr rpack(resource_handle, fin); *out = new PackedFunc( [func, rpack](DGLArgs args, DGLRetValue* rv) { int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*) args.num_args, rv, rpack.get()); if (ret != 0) { std::string err = "DGLCall CFunc Error:\n"; err += DGLGetLastError(); throw dmlc::Error(err); } }); } API_END(); } int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) { API_BEGIN(); DGLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx); API_END(); } int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) { API_BEGIN(); DGLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream); API_END(); } int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) { API_BEGIN(); DGLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; DeviceAPIManager::Get(ctx)->SetStream(ctx, stream); API_END(); } int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) { API_BEGIN(); DGLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; *stream = DeviceAPIManager::Get(ctx)->GetStream(); API_END(); } int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) { API_BEGIN(); DGLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream); API_END(); } int DGLStreamStreamSynchronize(int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst) { API_BEGIN(); DGLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst); API_END(); } int DGLCbArgToReturn(DGLValue* value, int code) { API_BEGIN(); dgl::runtime::DGLRetValue rv; rv = dgl::runtime::DGLArgValue(*value, code); int tcode; rv.MoveToCHost(value, &tcode); CHECK_EQ(tcode, code); API_END(); } int DGLLoadTensorAdapter(const char *path) { return TensorDispatcher::Global()->Load(path) ? 0 : -1; } // set device api DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device) .set_body([](DGLArgs args, DGLRetValue *ret) { DGLContext ctx; ctx.device_type = static_cast(args[0].operator int()); ctx.device_id = args[1]; DeviceAPIManager::Get(ctx)->SetDevice(ctx); }); // set device api DGL_REGISTER_GLOBAL("_GetDeviceAttr") .set_body([](DGLArgs args, DGLRetValue *ret) { DGLContext ctx; ctx.device_type = static_cast(args[0].operator int()); ctx.device_id = args[1]; DeviceAttrKind kind = static_cast(args[2].operator int()); if (kind == kExist) { DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true); if (api != nullptr) { api->GetAttr(ctx, kind, ret); } else { *ret = 0; } } else { DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret); } });