Unverified Commit 401e1278 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4811)



* [Misc] clang-format auto fix.

* fix

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 6c53f351
...@@ -2,14 +2,15 @@ ...@@ -2,14 +2,15 @@
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file file_util.cc * \file file_util.cc
*/ */
#include "file_util.h"
#include <dgl/runtime/serializer.h>
#include <dmlc/json.h> #include <dmlc/json.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dgl/runtime/serializer.h>
#include <fstream> #include <fstream>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "file_util.h"
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -52,8 +53,8 @@ bool FunctionInfo::Load(dmlc::Stream* reader) { ...@@ -52,8 +53,8 @@ bool FunctionInfo::Load(dmlc::Stream* reader) {
return true; return true;
} }
std::string GetFileFormat(const std::string& file_name, std::string GetFileFormat(
const std::string& format) { const std::string& file_name, const std::string& format) {
std::string fmt = format; std::string fmt = format;
if (fmt.length() == 0) { if (fmt.length() == 0) {
if (file_name.find(".signed.so") != std::string::npos) return "sgx"; if (file_name.find(".signed.so") != std::string::npos) return "sgx";
...@@ -87,7 +88,7 @@ std::string GetFileBasename(const std::string& file_name) { ...@@ -87,7 +88,7 @@ std::string GetFileBasename(const std::string& file_name) {
} }
std::string GetMetaFilePath(const std::string& file_name) { std::string GetMetaFilePath(const std::string& file_name) {
size_t pos = file_name.find_last_of("."); size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) { if (pos != std::string::npos) {
return file_name.substr(0, pos) + ".dgl_meta.json"; return file_name.substr(0, pos) + ".dgl_meta.json";
} else { } else {
...@@ -95,8 +96,7 @@ std::string GetMetaFilePath(const std::string& file_name) { ...@@ -95,8 +96,7 @@ std::string GetMetaFilePath(const std::string& file_name) {
} }
} }
void LoadBinaryFromFile(const std::string& file_name, void LoadBinaryFromFile(const std::string& file_name, std::string* data) {
std::string* data) {
std::ifstream fs(file_name, std::ios::in | std::ios::binary); std::ifstream fs(file_name, std::ios::in | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name; CHECK(!fs.fail()) << "Cannot open " << file_name;
// get its size: // get its size:
...@@ -107,9 +107,7 @@ void LoadBinaryFromFile(const std::string& file_name, ...@@ -107,9 +107,7 @@ void LoadBinaryFromFile(const std::string& file_name,
fs.read(&(*data)[0], size); fs.read(&(*data)[0], size);
} }
void SaveBinaryToFile( void SaveBinaryToFile(const std::string& file_name, const std::string& data) {
const std::string& file_name,
const std::string& data) {
std::ofstream fs(file_name, std::ios::out | std::ios::binary); std::ofstream fs(file_name, std::ios::out | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name; CHECK(!fs.fail()) << "Cannot open " << file_name;
fs.write(&data[0], data.length()); fs.write(&data[0], data.length());
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "meta_data.h" #include "meta_data.h"
namespace dgl { namespace dgl {
...@@ -17,8 +18,8 @@ namespace runtime { ...@@ -17,8 +18,8 @@ namespace runtime {
* \param file_name The name of the file. * \param file_name The name of the file.
* \param format The format of the file. * \param format The format of the file.
*/ */
std::string GetFileFormat(const std::string& file_name, std::string GetFileFormat(
const std::string& format); const std::string& file_name, const std::string& format);
/*! /*!
* \return the directory in which DGL stores cached files. * \return the directory in which DGL stores cached files.
...@@ -44,16 +45,14 @@ std::string GetFileBasename(const std::string& file_name); ...@@ -44,16 +45,14 @@ std::string GetFileBasename(const std::string& file_name);
* \param file_name The name of the file. * \param file_name The name of the file.
* \param data The data to be loaded. * \param data The data to be loaded.
*/ */
void LoadBinaryFromFile(const std::string& file_name, void LoadBinaryFromFile(const std::string& file_name, std::string* data);
std::string* data);
/*! /*!
* \brief Load binary file into a in-memory buffer. * \brief Load binary file into a in-memory buffer.
* \param file_name The name of the file. * \param file_name The name of the file.
* \param data The binary data to be saved. * \param data The binary data to be saved.
*/ */
void SaveBinaryToFile(const std::string& file_name, void SaveBinaryToFile(const std::string& file_name, const std::string& data);
const std::string& data);
/*! /*!
* \brief Save meta data to file. * \brief Save meta data to file.
......
...@@ -6,11 +6,13 @@ ...@@ -6,11 +6,13 @@
#ifndef DGL_RUNTIME_META_DATA_H_ #ifndef DGL_RUNTIME_META_DATA_H_
#define DGL_RUNTIME_META_DATA_H_ #define DGL_RUNTIME_META_DATA_H_
#include <dmlc/json.h>
#include <dmlc/io.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <dmlc/io.h>
#include <dmlc/json.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include "runtime_base.h" #include "runtime_base.h"
namespace dgl { namespace dgl {
......
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
* \brief DGL module system * \brief DGL module system
*/ */
#include <dgl/runtime/module.h> #include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <unordered_set> #include <dgl/runtime/registry.h>
#include <cstring> #include <cstring>
#include <unordered_set>
#ifndef _LIBCPP_SGX_CONFIG #ifndef _LIBCPP_SGX_CONFIG
#include "file_util.h" #include "file_util.h"
#endif #endif
...@@ -44,20 +45,18 @@ void Module::Import(Module other) { ...@@ -44,20 +45,18 @@ void Module::Import(Module other) {
node_->imports_.emplace_back(std::move(other)); node_->imports_.emplace_back(std::move(other));
} }
Module Module::LoadFromFile(const std::string& file_name, Module Module::LoadFromFile(
const std::string& format) { const std::string& file_name, const std::string& format) {
#ifndef _LIBCPP_SGX_CONFIG #ifndef _LIBCPP_SGX_CONFIG
std::string fmt = GetFileFormat(file_name, format); std::string fmt = GetFileFormat(file_name, format);
CHECK(fmt.length() != 0) CHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name;
<< "Cannot deduce format of file " << file_name;
if (fmt == "dll" || fmt == "dylib" || fmt == "dso") { if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
fmt = "so"; fmt = "so";
} }
std::string load_f_name = "module.loadfile_" + fmt; std::string load_f_name = "module.loadfile_" + fmt;
const PackedFunc* f = Registry::Get(load_f_name); const PackedFunc* f = Registry::Get(load_f_name);
CHECK(f != nullptr) CHECK(f != nullptr) << "Loader of " << format << "(" << load_f_name
<< "Loader of " << format << "(" << ") is not presented.";
<< load_f_name << ") is not presented.";
Module m = (*f)(file_name, format); Module m = (*f)(file_name, format);
return m; return m;
#else #else
...@@ -65,8 +64,8 @@ Module Module::LoadFromFile(const std::string& file_name, ...@@ -65,8 +64,8 @@ Module Module::LoadFromFile(const std::string& file_name,
#endif #endif
} }
void ModuleNode::SaveToFile(const std::string& file_name, void ModuleNode::SaveToFile(
const std::string& format) { const std::string& file_name, const std::string& format) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile"; LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
} }
...@@ -89,9 +88,8 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { ...@@ -89,9 +88,8 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
} }
if (pf == nullptr) { if (pf == nullptr) {
const PackedFunc* f = Registry::Get(name); const PackedFunc* f = Registry::Get(name);
CHECK(f != nullptr) CHECK(f != nullptr) << "Cannot find function " << name
<< "Cannot find function " << name << " in the imported modules or global registry";
<< " in the imported modules or global registry";
return f; return f;
} else { } else {
std::unique_ptr<PackedFunc> f(new PackedFunc(pf)); std::unique_ptr<PackedFunc> f(new PackedFunc(pf));
...@@ -125,7 +123,8 @@ bool RuntimeEnabled(const std::string& target) { ...@@ -125,7 +123,8 @@ bool RuntimeEnabled(const std::string& target) {
} else if (target.length() >= 4 && target.substr(0, 4) == "rocm") { } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
f_name = "device_api.rocm"; f_name = "device_api.rocm";
} else if (target.length() >= 4 && target.substr(0, 4) == "llvm") { } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled"); const PackedFunc* pf =
runtime::Registry::Get("codegen.llvm_target_enabled");
if (pf == nullptr) return false; if (pf == nullptr) return false;
return (*pf)(target); return (*pf)(target);
} else { } else {
...@@ -135,41 +134,38 @@ bool RuntimeEnabled(const std::string& target) { ...@@ -135,41 +134,38 @@ bool RuntimeEnabled(const std::string& target) {
} }
DGL_REGISTER_GLOBAL("module._Enabled") DGL_REGISTER_GLOBAL("module._Enabled")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = RuntimeEnabled(args[0]); *ret = RuntimeEnabled(args[0]);
}); });
DGL_REGISTER_GLOBAL("module._GetSource") DGL_REGISTER_GLOBAL("module._GetSource")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = args[0].operator Module()->GetSource(args[1]); *ret = args[0].operator Module()->GetSource(args[1]);
}); });
DGL_REGISTER_GLOBAL("module._ImportsSize") DGL_REGISTER_GLOBAL("module._ImportsSize")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = static_cast<int64_t>( *ret = static_cast<int64_t>(args[0].operator Module()->imports().size());
args[0].operator Module()->imports().size());
}); });
DGL_REGISTER_GLOBAL("module._GetImport") DGL_REGISTER_GLOBAL("module._GetImport")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = args[0].operator Module()-> *ret = args[0].operator Module()->imports().at(args[1].operator int());
imports().at(args[1].operator int());
}); });
DGL_REGISTER_GLOBAL("module._GetTypeKey") DGL_REGISTER_GLOBAL("module._GetTypeKey")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = std::string(args[0].operator Module()->type_key()); *ret = std::string(args[0].operator Module()->type_key());
}); });
DGL_REGISTER_GLOBAL("module._LoadFromFile") DGL_REGISTER_GLOBAL("module._LoadFromFile")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = Module::LoadFromFile(args[0], args[1]); *ret = Module::LoadFromFile(args[0], args[1]);
}); });
DGL_REGISTER_GLOBAL("module._SaveToFile") DGL_REGISTER_GLOBAL("module._SaveToFile")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
args[0].operator Module()-> args[0].operator Module()->SaveToFile(args[1], args[2]);
SaveToFile(args[1], args[2]);
}); });
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
...@@ -8,8 +8,10 @@ ...@@ -8,8 +8,10 @@
#endif #endif
#include <dgl/runtime/module.h> #include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h> #include <dgl/runtime/registry.h>
#include <string>
#include <memory> #include <memory>
#include <string>
#include "module_util.h" #include "module_util.h"
namespace dgl { namespace dgl {
...@@ -21,7 +23,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) { ...@@ -21,7 +23,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
uint64_t nbytes = 0; uint64_t nbytes = 0;
for (size_t i = 0; i < sizeof(nbytes); ++i) { for (size_t i = 0; i < sizeof(nbytes); ++i) {
uint64_t c = mblob[i]; uint64_t c = mblob[i];
nbytes |= (c & 0xffUL) << (i * 8); nbytes |= (c & 0xffUL) << (i * 8);
} }
dmlc::MemoryFixedSizeStream fs( dmlc::MemoryFixedSizeStream fs(
const_cast<char*>(mblob + sizeof(nbytes)), static_cast<size_t>(nbytes)); const_cast<char*>(mblob + sizeof(nbytes)), static_cast<size_t>(nbytes));
...@@ -33,9 +35,8 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) { ...@@ -33,9 +35,8 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
CHECK(stream->Read(&tkey)); CHECK(stream->Read(&tkey));
std::string fkey = "module.loadbinary_" + tkey; std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey); const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr) CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey
<< "Loader of " << tkey << "(" << ") is not presented.";
<< fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(stream)); Module m = (*f)(static_cast<void*>(stream));
mlist->push_back(m); mlist->push_back(m);
} }
...@@ -44,15 +45,14 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) { ...@@ -44,15 +45,14 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
#endif #endif
} }
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, PackedFunc WrapPackedFunc(
const std::shared_ptr<ModuleNode>& sptr_to_self) { BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](DGLArgs args, DGLRetValue* rv) { return PackedFunc([faddr, sptr_to_self](DGLArgs args, DGLRetValue* rv) {
int ret = (*faddr)( int ret = (*faddr)(
const_cast<DGLValue*>(args.values), const_cast<DGLValue*>(args.values), const_cast<int*>(args.type_codes),
const_cast<int*>(args.type_codes), args.num_args);
args.num_args); CHECK_EQ(ret, 0) << DGLGetLastError();
CHECK_EQ(ret, 0) << DGLGetLastError(); });
});
} }
} // namespace runtime } // namespace runtime
......
...@@ -6,17 +6,16 @@ ...@@ -6,17 +6,16 @@
#ifndef DGL_RUNTIME_MODULE_UTIL_H_ #ifndef DGL_RUNTIME_MODULE_UTIL_H_
#define DGL_RUNTIME_MODULE_UTIL_H_ #define DGL_RUNTIME_MODULE_UTIL_H_
#include <dgl/runtime/module.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h> #include <dgl/runtime/c_backend_api.h>
#include <vector> #include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/module.h>
#include <memory> #include <memory>
#include <vector>
extern "C" { extern "C" {
// Function signature for generated packed function in shared library // Function signature for generated packed function in shared library
typedef int (*BackendPackedCFunc)(void* args, typedef int (*BackendPackedCFunc)(void* args, int* type_codes, int num_args);
int* type_codes,
int num_args);
} // extern "C" } // extern "C"
namespace dgl { namespace dgl {
...@@ -26,7 +25,8 @@ namespace runtime { ...@@ -26,7 +25,8 @@ namespace runtime {
* \param faddr The function address * \param faddr The function address
* \param mptr The module pointer node. * \param mptr The module pointer node.
*/ */
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr); PackedFunc WrapPackedFunc(
BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr);
/*! /*!
* \brief Load and append module blob to module list * \brief Load and append module blob to module list
* \param mblob The module blob. * \param mblob The module blob.
...@@ -39,13 +39,13 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list); ...@@ -39,13 +39,13 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
* \param flookup A symbol lookup function. * \param flookup A symbol lookup function.
* \tparam FLookup a function of signature string->void* * \tparam FLookup a function of signature string->void*
*/ */
template<typename FLookup> template <typename FLookup>
void InitContextFunctions(FLookup flookup) { void InitContextFunctions(FLookup flookup) {
#define DGL_INIT_CONTEXT_FUNC(FuncName) \ #define DGL_INIT_CONTEXT_FUNC(FuncName) \
if (auto *fp = reinterpret_cast<decltype(&FuncName)*> \ if (auto* fp = \
(flookup("__" #FuncName))) { \ reinterpret_cast<decltype(&FuncName)*>(flookup("__" #FuncName))) { \
*fp = FuncName; \ *fp = FuncName; \
} }
// Initialize the functions // Initialize the functions
DGL_INIT_CONTEXT_FUNC(DGLFuncCall); DGL_INIT_CONTEXT_FUNC(DGLFuncCall);
DGL_INIT_CONTEXT_FUNC(DGLAPISetLastError); DGL_INIT_CONTEXT_FUNC(DGLAPISetLastError);
...@@ -55,8 +55,8 @@ void InitContextFunctions(FLookup flookup) { ...@@ -55,8 +55,8 @@ void InitContextFunctions(FLookup flookup) {
DGL_INIT_CONTEXT_FUNC(DGLBackendParallelLaunch); DGL_INIT_CONTEXT_FUNC(DGLBackendParallelLaunch);
DGL_INIT_CONTEXT_FUNC(DGLBackendParallelBarrier); DGL_INIT_CONTEXT_FUNC(DGLBackendParallelBarrier);
#undef DGL_INIT_CONTEXT_FUNC #undef DGL_INIT_CONTEXT_FUNC
} }
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
#endif // DGL_RUNTIME_MODULE_UTIL_H_ #endif // DGL_RUNTIME_MODULE_UTIL_H_
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
* \brief Implementation of runtime object APIs. * \brief Implementation of runtime object APIs.
*/ */
#include <dgl/runtime/object.h> #include <dgl/runtime/object.h>
#include <memory>
#include <atomic> #include <atomic>
#include <memory>
#include <mutex> #include <mutex>
#include <unordered_map> #include <unordered_map>
...@@ -36,7 +37,7 @@ bool Object::_DerivedFrom(uint32_t tid) const { ...@@ -36,7 +37,7 @@ bool Object::_DerivedFrom(uint32_t tid) const {
// this is slow, usually caller always hold the result in a static variable. // this is slow, usually caller always hold the result in a static variable.
uint32_t Object::TypeKey2Index(const char* key) { uint32_t Object::TypeKey2Index(const char* key) {
TypeManager *t = TypeManager::Global(); TypeManager* t = TypeManager::Global();
std::lock_guard<std::mutex> lock(t->mutex); std::lock_guard<std::mutex> lock(t->mutex);
std::string skey = key; std::string skey = key;
auto it = t->key2index.find(skey); auto it = t->key2index.find(skey);
...@@ -50,7 +51,7 @@ uint32_t Object::TypeKey2Index(const char* key) { ...@@ -50,7 +51,7 @@ uint32_t Object::TypeKey2Index(const char* key) {
} }
const char* Object::TypeIndex2Key(uint32_t index) { const char* Object::TypeIndex2Key(uint32_t index) {
TypeManager *t = TypeManager::Global(); TypeManager* t = TypeManager::Global();
std::lock_guard<std::mutex> lock(t->mutex); std::lock_guard<std::mutex> lock(t->mutex);
CHECK_NE(index, 0); CHECK_NE(index, 0);
return t->index2key.at(index - 1).c_str(); return t->index2key.at(index - 1).c_str();
......
This diff is collapsed.
...@@ -8,4 +8,4 @@ namespace dgl { ...@@ -8,4 +8,4 @@ namespace dgl {
namespace runtime { namespace runtime {
DefaultGrainSizeT default_grain_size; DefaultGrainSizeT default_grain_size;
} // namespace runtime } // namespace runtime
} // namesoace dgl } // namespace dgl
...@@ -3,13 +3,15 @@ ...@@ -3,13 +3,15 @@
* \file registry.cc * \file registry.cc
* \brief The global registry of packed function. * \brief The global registry of packed function.
*/ */
#include <dgl/runtime/registry.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <unordered_map>
#include <mutex>
#include <memory>
#include <array> #include <array>
#include <memory>
#include <mutex>
#include <unordered_map>
#include "runtime_base.h" #include "runtime_base.h"
namespace dgl { namespace dgl {
...@@ -18,9 +20,10 @@ namespace runtime { ...@@ -18,9 +20,10 @@ namespace runtime {
struct Registry::Manager { struct Registry::Manager {
// map storing the functions. // map storing the functions.
// We delibrately used raw pointer // We delibrately used raw pointer
// This is because PackedFunc can contain callbacks into the host languge(python) // This is because PackedFunc can contain callbacks into the host
// and the resource can become invalid because of indeterminstic order of destruction. // languge(python) and the resource can become invalid because of
// The resources will only be recycled during program exit. // indeterminstic order of destruction. The resources will only be recycled
// during program exit.
std::unordered_map<std::string, Registry*> fmap; std::unordered_map<std::string, Registry*> fmap;
// vtable for extension type // vtable for extension type
std::array<ExtTypeVTable, kExtEnd> ext_vtable; std::array<ExtTypeVTable, kExtEnd> ext_vtable;
...@@ -44,7 +47,8 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*) ...@@ -44,7 +47,8 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*)
return *this; return *this;
} }
Registry& Registry::Register(const std::string& name, bool override) { // NOLINT(*) Registry& Registry::Register(
const std::string& name, bool override) { // NOLINT(*)
Manager* m = Manager::Global(); Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex); std::lock_guard<std::mutex> lock(m->mutex);
auto it = m->fmap.find(name); auto it = m->fmap.find(name);
...@@ -54,8 +58,7 @@ Registry& Registry::Register(const std::string& name, bool override) { // NOLIN ...@@ -54,8 +58,7 @@ Registry& Registry::Register(const std::string& name, bool override) { // NOLIN
m->fmap[name] = r; m->fmap[name] = r;
return *r; return *r;
} else { } else {
CHECK(override) CHECK(override) << "Global PackedFunc " << name << " is already registered";
<< "Global PackedFunc " << name << " is already registered";
return *it->second; return *it->second;
} }
} }
...@@ -82,7 +85,7 @@ std::vector<std::string> Registry::ListNames() { ...@@ -82,7 +85,7 @@ std::vector<std::string> Registry::ListNames() {
std::lock_guard<std::mutex> lock(m->mutex); std::lock_guard<std::mutex> lock(m->mutex);
std::vector<std::string> keys; std::vector<std::string> keys;
keys.reserve(m->fmap.size()); keys.reserve(m->fmap.size());
for (const auto &kv : m->fmap) { for (const auto& kv : m->fmap) {
keys.push_back(kv.first); keys.push_back(kv.first);
} }
return keys; return keys;
...@@ -92,8 +95,7 @@ ExtTypeVTable* ExtTypeVTable::Get(int type_code) { ...@@ -92,8 +95,7 @@ ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
CHECK(type_code > kExtBegin && type_code < kExtEnd); CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global(); Registry::Manager* m = Registry::Manager::Global();
ExtTypeVTable* vt = &(m->ext_vtable[type_code]); ExtTypeVTable* vt = &(m->ext_vtable[type_code]);
CHECK(vt->destroy != nullptr) CHECK(vt->destroy != nullptr) << "Extension type not registered";
<< "Extension type not registered";
return vt; return vt;
} }
...@@ -114,7 +116,7 @@ struct DGLFuncThreadLocalEntry { ...@@ -114,7 +116,7 @@ struct DGLFuncThreadLocalEntry {
/*! \brief result holder for returning strings */ /*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str; std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */ /*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp; std::vector<const char*> ret_vec_charp;
}; };
/*! \brief Thread local store that can be used to hold return values. */ /*! \brief Thread local store that can be used to hold return values. */
...@@ -126,8 +128,7 @@ int DGLExtTypeFree(void* handle, int type_code) { ...@@ -126,8 +128,7 @@ int DGLExtTypeFree(void* handle, int type_code) {
API_END(); API_END();
} }
int DGLFuncRegisterGlobal( int DGLFuncRegisterGlobal(const char* name, DGLFunctionHandle f, int override) {
const char* name, DGLFunctionHandle f, int override) {
API_BEGIN(); API_BEGIN();
dgl::runtime::Registry::Register(name, override != 0) dgl::runtime::Registry::Register(name, override != 0)
.set_body(*static_cast<dgl::runtime::PackedFunc*>(f)); .set_body(*static_cast<dgl::runtime::PackedFunc*>(f));
...@@ -136,8 +137,7 @@ int DGLFuncRegisterGlobal( ...@@ -136,8 +137,7 @@ int DGLFuncRegisterGlobal(
int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) { int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) {
API_BEGIN(); API_BEGIN();
const dgl::runtime::PackedFunc* fp = const dgl::runtime::PackedFunc* fp = dgl::runtime::Registry::Get(name);
dgl::runtime::Registry::Get(name);
if (fp != nullptr) { if (fp != nullptr) {
*out = new dgl::runtime::PackedFunc(*fp); // NOLINT(*) *out = new dgl::runtime::PackedFunc(*fp); // NOLINT(*)
} else { } else {
...@@ -146,10 +146,9 @@ int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) { ...@@ -146,10 +146,9 @@ int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) {
API_END(); API_END();
} }
int DGLFuncListGlobalNames(int *out_size, int DGLFuncListGlobalNames(int* out_size, const char*** out_array) {
const char*** out_array) {
API_BEGIN(); API_BEGIN();
DGLFuncThreadLocalEntry *ret = DGLFuncThreadLocalStore::Get(); DGLFuncThreadLocalEntry* ret = DGLFuncThreadLocalStore::Get();
ret->ret_vec_str = dgl::runtime::Registry::ListNames(); ret->ret_vec_str = dgl::runtime::Registry::ListNames();
ret->ret_vec_charp.clear(); ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
......
...@@ -14,9 +14,10 @@ namespace dgl { ...@@ -14,9 +14,10 @@ namespace dgl {
namespace runtime { namespace runtime {
/* /*
* The runtime allocates resources during the computation. Some of the resources cannot be * The runtime allocates resources during the computation. Some of the resources
* destroyed after the process exits especially when the process doesn't exits normally. * cannot be destroyed after the process exits especially when the process
* We need to keep track of the resources in the system and clean them up properly. * doesn't exits normally. We need to keep track of the resources in the system
* and clean them up properly.
*/ */
class ResourceManager { class ResourceManager {
std::unordered_map<std::string, std::shared_ptr<Resource>> resources; std::unordered_map<std::string, std::shared_ptr<Resource>> resources;
...@@ -25,12 +26,11 @@ class ResourceManager { ...@@ -25,12 +26,11 @@ class ResourceManager {
void Add(const std::string &key, std::shared_ptr<Resource> resource) { void Add(const std::string &key, std::shared_ptr<Resource> resource) {
auto it = resources.find(key); auto it = resources.find(key);
CHECK(it == resources.end()) << key << " already exists"; CHECK(it == resources.end()) << key << " already exists";
resources.insert(std::pair<std::string, std::shared_ptr<Resource>>(key, resource)); resources.insert(
std::pair<std::string, std::shared_ptr<Resource>>(key, resource));
} }
void Erase(const std::string &key) { void Erase(const std::string &key) { resources.erase(key); }
resources.erase(key);
}
void Cleanup() { void Cleanup() {
for (auto it = resources.begin(); it != resources.end(); it++) { for (auto it = resources.begin(); it != resources.end(); it++) {
...@@ -46,13 +46,9 @@ void AddResource(const std::string &key, std::shared_ptr<Resource> resource) { ...@@ -46,13 +46,9 @@ void AddResource(const std::string &key, std::shared_ptr<Resource> resource) {
manager.Add(key, resource); manager.Add(key, resource);
} }
void DeleteResource(const std::string &key) { void DeleteResource(const std::string &key) { manager.Erase(key); }
manager.Erase(key);
}
void CleanupResources() { void CleanupResources() { manager.Cleanup(); }
manager.Cleanup();
}
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include <semaphore.h> #include <semaphore.h>
#endif #endif
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -31,7 +30,8 @@ class Semaphore { ...@@ -31,7 +30,8 @@ class Semaphore {
void Wait(); void Wait();
/*! /*!
* \brief timed wait, decrease semaphore by 1 or returns if times out * \brief timed wait, decrease semaphore by 1 or returns if times out
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely. * \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
*/ */
bool TimedWait(int timeout); bool TimedWait(int timeout);
/*! /*!
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment