Unverified Commit 401e1278 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4811)



* [Misc] clang-format auto fix.

* fix

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 6c53f351
...@@ -2,14 +2,15 @@ ...@@ -2,14 +2,15 @@
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file file_util.cc * \file file_util.cc
*/ */
#include "file_util.h"
#include <dgl/runtime/serializer.h>
#include <dmlc/json.h> #include <dmlc/json.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dgl/runtime/serializer.h>
#include <fstream> #include <fstream>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "file_util.h"
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -52,8 +53,8 @@ bool FunctionInfo::Load(dmlc::Stream* reader) { ...@@ -52,8 +53,8 @@ bool FunctionInfo::Load(dmlc::Stream* reader) {
return true; return true;
} }
std::string GetFileFormat(const std::string& file_name, std::string GetFileFormat(
const std::string& format) { const std::string& file_name, const std::string& format) {
std::string fmt = format; std::string fmt = format;
if (fmt.length() == 0) { if (fmt.length() == 0) {
if (file_name.find(".signed.so") != std::string::npos) return "sgx"; if (file_name.find(".signed.so") != std::string::npos) return "sgx";
...@@ -87,7 +88,7 @@ std::string GetFileBasename(const std::string& file_name) { ...@@ -87,7 +88,7 @@ std::string GetFileBasename(const std::string& file_name) {
} }
std::string GetMetaFilePath(const std::string& file_name) { std::string GetMetaFilePath(const std::string& file_name) {
size_t pos = file_name.find_last_of("."); size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) { if (pos != std::string::npos) {
return file_name.substr(0, pos) + ".dgl_meta.json"; return file_name.substr(0, pos) + ".dgl_meta.json";
} else { } else {
...@@ -95,8 +96,7 @@ std::string GetMetaFilePath(const std::string& file_name) { ...@@ -95,8 +96,7 @@ std::string GetMetaFilePath(const std::string& file_name) {
} }
} }
void LoadBinaryFromFile(const std::string& file_name, void LoadBinaryFromFile(const std::string& file_name, std::string* data) {
std::string* data) {
std::ifstream fs(file_name, std::ios::in | std::ios::binary); std::ifstream fs(file_name, std::ios::in | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name; CHECK(!fs.fail()) << "Cannot open " << file_name;
// get its size: // get its size:
...@@ -107,9 +107,7 @@ void LoadBinaryFromFile(const std::string& file_name, ...@@ -107,9 +107,7 @@ void LoadBinaryFromFile(const std::string& file_name,
fs.read(&(*data)[0], size); fs.read(&(*data)[0], size);
} }
void SaveBinaryToFile( void SaveBinaryToFile(const std::string& file_name, const std::string& data) {
const std::string& file_name,
const std::string& data) {
std::ofstream fs(file_name, std::ios::out | std::ios::binary); std::ofstream fs(file_name, std::ios::out | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name; CHECK(!fs.fail()) << "Cannot open " << file_name;
fs.write(&data[0], data.length()); fs.write(&data[0], data.length());
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "meta_data.h" #include "meta_data.h"
namespace dgl { namespace dgl {
...@@ -17,8 +18,8 @@ namespace runtime { ...@@ -17,8 +18,8 @@ namespace runtime {
* \param file_name The name of the file. * \param file_name The name of the file.
* \param format The format of the file. * \param format The format of the file.
*/ */
std::string GetFileFormat(const std::string& file_name, std::string GetFileFormat(
const std::string& format); const std::string& file_name, const std::string& format);
/*! /*!
* \return the directory in which DGL stores cached files. * \return the directory in which DGL stores cached files.
...@@ -44,16 +45,14 @@ std::string GetFileBasename(const std::string& file_name); ...@@ -44,16 +45,14 @@ std::string GetFileBasename(const std::string& file_name);
* \param file_name The name of the file. * \param file_name The name of the file.
* \param data The data to be loaded. * \param data The data to be loaded.
*/ */
void LoadBinaryFromFile(const std::string& file_name, void LoadBinaryFromFile(const std::string& file_name, std::string* data);
std::string* data);
/*! /*!
* \brief Load binary file into a in-memory buffer. * \brief Load binary file into a in-memory buffer.
* \param file_name The name of the file. * \param file_name The name of the file.
* \param data The binary data to be saved. * \param data The binary data to be saved.
*/ */
void SaveBinaryToFile(const std::string& file_name, void SaveBinaryToFile(const std::string& file_name, const std::string& data);
const std::string& data);
/*! /*!
* \brief Save meta data to file. * \brief Save meta data to file.
......
...@@ -6,11 +6,13 @@ ...@@ -6,11 +6,13 @@
#ifndef DGL_RUNTIME_META_DATA_H_ #ifndef DGL_RUNTIME_META_DATA_H_
#define DGL_RUNTIME_META_DATA_H_ #define DGL_RUNTIME_META_DATA_H_
#include <dmlc/json.h>
#include <dmlc/io.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <dmlc/io.h>
#include <dmlc/json.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include "runtime_base.h" #include "runtime_base.h"
namespace dgl { namespace dgl {
......
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
* \brief DGL module system * \brief DGL module system
*/ */
#include <dgl/runtime/module.h> #include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <unordered_set> #include <dgl/runtime/registry.h>
#include <cstring> #include <cstring>
#include <unordered_set>
#ifndef _LIBCPP_SGX_CONFIG #ifndef _LIBCPP_SGX_CONFIG
#include "file_util.h" #include "file_util.h"
#endif #endif
...@@ -44,20 +45,18 @@ void Module::Import(Module other) { ...@@ -44,20 +45,18 @@ void Module::Import(Module other) {
node_->imports_.emplace_back(std::move(other)); node_->imports_.emplace_back(std::move(other));
} }
Module Module::LoadFromFile(const std::string& file_name, Module Module::LoadFromFile(
const std::string& format) { const std::string& file_name, const std::string& format) {
#ifndef _LIBCPP_SGX_CONFIG #ifndef _LIBCPP_SGX_CONFIG
std::string fmt = GetFileFormat(file_name, format); std::string fmt = GetFileFormat(file_name, format);
CHECK(fmt.length() != 0) CHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name;
<< "Cannot deduce format of file " << file_name;
if (fmt == "dll" || fmt == "dylib" || fmt == "dso") { if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
fmt = "so"; fmt = "so";
} }
std::string load_f_name = "module.loadfile_" + fmt; std::string load_f_name = "module.loadfile_" + fmt;
const PackedFunc* f = Registry::Get(load_f_name); const PackedFunc* f = Registry::Get(load_f_name);
CHECK(f != nullptr) CHECK(f != nullptr) << "Loader of " << format << "(" << load_f_name
<< "Loader of " << format << "(" << ") is not presented.";
<< load_f_name << ") is not presented.";
Module m = (*f)(file_name, format); Module m = (*f)(file_name, format);
return m; return m;
#else #else
...@@ -65,8 +64,8 @@ Module Module::LoadFromFile(const std::string& file_name, ...@@ -65,8 +64,8 @@ Module Module::LoadFromFile(const std::string& file_name,
#endif #endif
} }
void ModuleNode::SaveToFile(const std::string& file_name, void ModuleNode::SaveToFile(
const std::string& format) { const std::string& file_name, const std::string& format) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile"; LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
} }
...@@ -89,9 +88,8 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { ...@@ -89,9 +88,8 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
} }
if (pf == nullptr) { if (pf == nullptr) {
const PackedFunc* f = Registry::Get(name); const PackedFunc* f = Registry::Get(name);
CHECK(f != nullptr) CHECK(f != nullptr) << "Cannot find function " << name
<< "Cannot find function " << name << " in the imported modules or global registry";
<< " in the imported modules or global registry";
return f; return f;
} else { } else {
std::unique_ptr<PackedFunc> f(new PackedFunc(pf)); std::unique_ptr<PackedFunc> f(new PackedFunc(pf));
...@@ -125,7 +123,8 @@ bool RuntimeEnabled(const std::string& target) { ...@@ -125,7 +123,8 @@ bool RuntimeEnabled(const std::string& target) {
} else if (target.length() >= 4 && target.substr(0, 4) == "rocm") { } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
f_name = "device_api.rocm"; f_name = "device_api.rocm";
} else if (target.length() >= 4 && target.substr(0, 4) == "llvm") { } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled"); const PackedFunc* pf =
runtime::Registry::Get("codegen.llvm_target_enabled");
if (pf == nullptr) return false; if (pf == nullptr) return false;
return (*pf)(target); return (*pf)(target);
} else { } else {
...@@ -135,41 +134,38 @@ bool RuntimeEnabled(const std::string& target) { ...@@ -135,41 +134,38 @@ bool RuntimeEnabled(const std::string& target) {
} }
DGL_REGISTER_GLOBAL("module._Enabled") DGL_REGISTER_GLOBAL("module._Enabled")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = RuntimeEnabled(args[0]); *ret = RuntimeEnabled(args[0]);
}); });
DGL_REGISTER_GLOBAL("module._GetSource") DGL_REGISTER_GLOBAL("module._GetSource")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = args[0].operator Module()->GetSource(args[1]); *ret = args[0].operator Module()->GetSource(args[1]);
}); });
DGL_REGISTER_GLOBAL("module._ImportsSize") DGL_REGISTER_GLOBAL("module._ImportsSize")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = static_cast<int64_t>( *ret = static_cast<int64_t>(args[0].operator Module()->imports().size());
args[0].operator Module()->imports().size());
}); });
DGL_REGISTER_GLOBAL("module._GetImport") DGL_REGISTER_GLOBAL("module._GetImport")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = args[0].operator Module()-> *ret = args[0].operator Module()->imports().at(args[1].operator int());
imports().at(args[1].operator int());
}); });
DGL_REGISTER_GLOBAL("module._GetTypeKey") DGL_REGISTER_GLOBAL("module._GetTypeKey")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = std::string(args[0].operator Module()->type_key()); *ret = std::string(args[0].operator Module()->type_key());
}); });
DGL_REGISTER_GLOBAL("module._LoadFromFile") DGL_REGISTER_GLOBAL("module._LoadFromFile")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = Module::LoadFromFile(args[0], args[1]); *ret = Module::LoadFromFile(args[0], args[1]);
}); });
DGL_REGISTER_GLOBAL("module._SaveToFile") DGL_REGISTER_GLOBAL("module._SaveToFile")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
args[0].operator Module()-> args[0].operator Module()->SaveToFile(args[1], args[2]);
SaveToFile(args[1], args[2]);
}); });
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
...@@ -8,8 +8,10 @@ ...@@ -8,8 +8,10 @@
#endif #endif
#include <dgl/runtime/module.h> #include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h> #include <dgl/runtime/registry.h>
#include <string>
#include <memory> #include <memory>
#include <string>
#include "module_util.h" #include "module_util.h"
namespace dgl { namespace dgl {
...@@ -21,7 +23,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) { ...@@ -21,7 +23,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
uint64_t nbytes = 0; uint64_t nbytes = 0;
for (size_t i = 0; i < sizeof(nbytes); ++i) { for (size_t i = 0; i < sizeof(nbytes); ++i) {
uint64_t c = mblob[i]; uint64_t c = mblob[i];
nbytes |= (c & 0xffUL) << (i * 8); nbytes |= (c & 0xffUL) << (i * 8);
} }
dmlc::MemoryFixedSizeStream fs( dmlc::MemoryFixedSizeStream fs(
const_cast<char*>(mblob + sizeof(nbytes)), static_cast<size_t>(nbytes)); const_cast<char*>(mblob + sizeof(nbytes)), static_cast<size_t>(nbytes));
...@@ -33,9 +35,8 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) { ...@@ -33,9 +35,8 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
CHECK(stream->Read(&tkey)); CHECK(stream->Read(&tkey));
std::string fkey = "module.loadbinary_" + tkey; std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey); const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr) CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey
<< "Loader of " << tkey << "(" << ") is not presented.";
<< fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(stream)); Module m = (*f)(static_cast<void*>(stream));
mlist->push_back(m); mlist->push_back(m);
} }
...@@ -44,15 +45,14 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) { ...@@ -44,15 +45,14 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
#endif #endif
} }
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, PackedFunc WrapPackedFunc(
const std::shared_ptr<ModuleNode>& sptr_to_self) { BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](DGLArgs args, DGLRetValue* rv) { return PackedFunc([faddr, sptr_to_self](DGLArgs args, DGLRetValue* rv) {
int ret = (*faddr)( int ret = (*faddr)(
const_cast<DGLValue*>(args.values), const_cast<DGLValue*>(args.values), const_cast<int*>(args.type_codes),
const_cast<int*>(args.type_codes), args.num_args);
args.num_args); CHECK_EQ(ret, 0) << DGLGetLastError();
CHECK_EQ(ret, 0) << DGLGetLastError(); });
});
} }
} // namespace runtime } // namespace runtime
......
...@@ -6,17 +6,16 @@ ...@@ -6,17 +6,16 @@
#ifndef DGL_RUNTIME_MODULE_UTIL_H_ #ifndef DGL_RUNTIME_MODULE_UTIL_H_
#define DGL_RUNTIME_MODULE_UTIL_H_ #define DGL_RUNTIME_MODULE_UTIL_H_
#include <dgl/runtime/module.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h> #include <dgl/runtime/c_backend_api.h>
#include <vector> #include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/module.h>
#include <memory> #include <memory>
#include <vector>
extern "C" { extern "C" {
// Function signature for generated packed function in shared library // Function signature for generated packed function in shared library
typedef int (*BackendPackedCFunc)(void* args, typedef int (*BackendPackedCFunc)(void* args, int* type_codes, int num_args);
int* type_codes,
int num_args);
} // extern "C" } // extern "C"
namespace dgl { namespace dgl {
...@@ -26,7 +25,8 @@ namespace runtime { ...@@ -26,7 +25,8 @@ namespace runtime {
* \param faddr The function address * \param faddr The function address
* \param mptr The module pointer node. * \param mptr The module pointer node.
*/ */
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr); PackedFunc WrapPackedFunc(
BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr);
/*! /*!
* \brief Load and append module blob to module list * \brief Load and append module blob to module list
* \param mblob The module blob. * \param mblob The module blob.
...@@ -39,13 +39,13 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list); ...@@ -39,13 +39,13 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
* \param flookup A symbol lookup function. * \param flookup A symbol lookup function.
* \tparam FLookup a function of signature string->void* * \tparam FLookup a function of signature string->void*
*/ */
template<typename FLookup> template <typename FLookup>
void InitContextFunctions(FLookup flookup) { void InitContextFunctions(FLookup flookup) {
#define DGL_INIT_CONTEXT_FUNC(FuncName) \ #define DGL_INIT_CONTEXT_FUNC(FuncName) \
if (auto *fp = reinterpret_cast<decltype(&FuncName)*> \ if (auto* fp = \
(flookup("__" #FuncName))) { \ reinterpret_cast<decltype(&FuncName)*>(flookup("__" #FuncName))) { \
*fp = FuncName; \ *fp = FuncName; \
} }
// Initialize the functions // Initialize the functions
DGL_INIT_CONTEXT_FUNC(DGLFuncCall); DGL_INIT_CONTEXT_FUNC(DGLFuncCall);
DGL_INIT_CONTEXT_FUNC(DGLAPISetLastError); DGL_INIT_CONTEXT_FUNC(DGLAPISetLastError);
...@@ -55,8 +55,8 @@ void InitContextFunctions(FLookup flookup) { ...@@ -55,8 +55,8 @@ void InitContextFunctions(FLookup flookup) {
DGL_INIT_CONTEXT_FUNC(DGLBackendParallelLaunch); DGL_INIT_CONTEXT_FUNC(DGLBackendParallelLaunch);
DGL_INIT_CONTEXT_FUNC(DGLBackendParallelBarrier); DGL_INIT_CONTEXT_FUNC(DGLBackendParallelBarrier);
#undef DGL_INIT_CONTEXT_FUNC #undef DGL_INIT_CONTEXT_FUNC
} }
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
#endif // DGL_RUNTIME_MODULE_UTIL_H_ #endif // DGL_RUNTIME_MODULE_UTIL_H_
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
* \brief Implementation of runtime object APIs. * \brief Implementation of runtime object APIs.
*/ */
#include <dgl/runtime/object.h> #include <dgl/runtime/object.h>
#include <memory>
#include <atomic> #include <atomic>
#include <memory>
#include <mutex> #include <mutex>
#include <unordered_map> #include <unordered_map>
...@@ -36,7 +37,7 @@ bool Object::_DerivedFrom(uint32_t tid) const { ...@@ -36,7 +37,7 @@ bool Object::_DerivedFrom(uint32_t tid) const {
// this is slow, usually caller always hold the result in a static variable. // this is slow, usually caller always hold the result in a static variable.
uint32_t Object::TypeKey2Index(const char* key) { uint32_t Object::TypeKey2Index(const char* key) {
TypeManager *t = TypeManager::Global(); TypeManager* t = TypeManager::Global();
std::lock_guard<std::mutex> lock(t->mutex); std::lock_guard<std::mutex> lock(t->mutex);
std::string skey = key; std::string skey = key;
auto it = t->key2index.find(skey); auto it = t->key2index.find(skey);
...@@ -50,7 +51,7 @@ uint32_t Object::TypeKey2Index(const char* key) { ...@@ -50,7 +51,7 @@ uint32_t Object::TypeKey2Index(const char* key) {
} }
const char* Object::TypeIndex2Key(uint32_t index) { const char* Object::TypeIndex2Key(uint32_t index) {
TypeManager *t = TypeManager::Global(); TypeManager* t = TypeManager::Global();
std::lock_guard<std::mutex> lock(t->mutex); std::lock_guard<std::mutex> lock(t->mutex);
CHECK_NE(index, 0); CHECK_NE(index, 0);
return t->index2key.at(index - 1).c_str(); return t->index2key.at(index - 1).c_str();
......
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file pack_args.h * \file pack_args.h
* \brief Utility to pack DGLArgs to other type-erased fution calling convention. * \brief Utility to pack DGLArgs to other type-erased fution calling
* convention.
* *
* Two type erased function signatures are supported. * Two type erased function signatures are supported.
* - cuda_style(void** args, int num_args); * - cuda_style(void** args, int num_args);
...@@ -15,8 +16,9 @@ ...@@ -15,8 +16,9 @@
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <vector>
#include <cstring> #include <cstring>
#include <vector>
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -38,10 +40,12 @@ union ArgUnion { ...@@ -38,10 +40,12 @@ union ArgUnion {
* *
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
template<typename F> template <typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLDataType>& arg_types); inline PackedFunc PackFuncVoidAddr(
F f, const std::vector<DGLDataType>& arg_types);
/*! /*!
* \brief Create a packed function that from function only packs buffer arguments. * \brief Create a packed function that from function only packs buffer
* arguments.
* *
* \param f with signiture (DGLArgs args, DGLRetValue* rv, ArgUnion* pack_args) * \param f with signiture (DGLArgs args, DGLRetValue* rv, ArgUnion* pack_args)
* \param arg_types The arguments type information. * \param arg_types The arguments type information.
...@@ -49,19 +53,23 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLDataType>& arg_type ...@@ -49,19 +53,23 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLDataType>& arg_type
* *
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
template<typename F> template <typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLDataType>& arg_types); inline PackedFunc PackFuncNonBufferArg(
F f, const std::vector<DGLDataType>& arg_types);
/*! /*!
* \brief Create a packed function that from function that takes a packed arguments. * \brief Create a packed function that from function that takes a packed
* arguments.
* *
* \param f with signature (DGLArgs args, DGLRetValue* rv, void* pack_args, size_t nbytes) * \param f with signature (DGLArgs args, DGLRetValue* rv, void* pack_args,
* size_t nbytes)
* \param arg_types The arguments that wish to get from * \param arg_types The arguments that wish to get from
* \tparam F the function type * \tparam F the function type
* *
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
template<typename F> template <typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DGLDataType>& arg_types); inline PackedFunc PackFuncPackedArg(
F f, const std::vector<DGLDataType>& arg_types);
/*! /*!
* \brief Extract number of buffer argument from the argument types. * \brief Extract number of buffer argument from the argument types.
* \param arg_types The argument types. * \param arg_types The argument types.
...@@ -71,23 +79,21 @@ inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types); ...@@ -71,23 +79,21 @@ inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types);
// implementations details // implementations details
namespace detail { namespace detail {
template<typename T, int kSize> template <typename T, int kSize>
class TempArray { class TempArray {
public: public:
explicit TempArray(int size) {} explicit TempArray(int size) {}
T* data() { T* data() { return data_; }
return data_;
}
private: private:
T data_[kSize]; T data_[kSize];
}; };
template<typename T> template <typename T>
class TempArray<T, 0> { class TempArray<T, 0> {
public: public:
explicit TempArray(int size) : data_(size) {} explicit TempArray(int size) : data_(size) {}
T* data() { T* data() { return data_.data(); }
return data_.data();
}
private: private:
std::vector<T> data_; std::vector<T> data_;
}; };
...@@ -120,8 +126,9 @@ inline ArgConvertCode GetArgConvertCode(DGLDataType t) { ...@@ -120,8 +126,9 @@ inline ArgConvertCode GetArgConvertCode(DGLDataType t) {
return HANDLE_TO_HANDLE; return HANDLE_TO_HANDLE;
} }
template<int N, typename F> template <int N, typename F>
inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) { inline PackedFunc PackFuncVoidAddr_(
F f, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size()); int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, num_args](DGLArgs args, DGLRetValue* ret) { auto ret = [f, codes, num_args](DGLArgs args, DGLRetValue* ret) {
TempArray<void*, N> addr_(num_args); TempArray<void*, N> addr_(num_args);
...@@ -141,7 +148,7 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& code ...@@ -141,7 +148,7 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& code
addr[i] = &(holder[i]); addr[i] = &(holder[i]);
break; break;
} }
case INT64_TO_UINT32 : { case INT64_TO_UINT32: {
holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64); holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
addr[i] = &(holder[i]); addr[i] = &(holder[i]);
break; break;
...@@ -158,7 +165,7 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& code ...@@ -158,7 +165,7 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& code
return PackedFunc(ret); return PackedFunc(ret);
} }
template<int N, typename F> template <int N, typename F>
inline PackedFunc PackFuncNonBufferArg_( inline PackedFunc PackFuncNonBufferArg_(
F f, int base, const std::vector<ArgConvertCode>& codes) { F f, int base, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size()); int num_args = static_cast<int>(codes.size());
...@@ -169,22 +176,27 @@ inline PackedFunc PackFuncNonBufferArg_( ...@@ -169,22 +176,27 @@ inline PackedFunc PackFuncNonBufferArg_(
switch (codes[i]) { switch (codes[i]) {
case INT64_TO_INT64: case INT64_TO_INT64:
case FLOAT64_TO_FLOAT64: { case FLOAT64_TO_FLOAT64: {
LOG(FATAL) << "Donot support 64bit argument to device function"; break; LOG(FATAL) << "Donot support 64bit argument to device function";
break;
} }
case INT64_TO_INT32: { case INT64_TO_INT32: {
holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64); holder[i].v_int32 =
static_cast<int32_t>(args.values[base + i].v_int64);
break; break;
} }
case INT64_TO_UINT32 : { case INT64_TO_UINT32: {
holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64); holder[i].v_uint32 =
static_cast<uint32_t>(args.values[base + i].v_int64);
break; break;
} }
case FLOAT64_TO_FLOAT32: { case FLOAT64_TO_FLOAT32: {
holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64); holder[i].v_float32 =
static_cast<float>(args.values[base + i].v_float64);
break; break;
} }
case HANDLE_TO_HANDLE: { case HANDLE_TO_HANDLE: {
LOG(FATAL) << "not reached"; break; LOG(FATAL) << "not reached";
break;
} }
} }
} }
...@@ -193,7 +205,7 @@ inline PackedFunc PackFuncNonBufferArg_( ...@@ -193,7 +205,7 @@ inline PackedFunc PackFuncNonBufferArg_(
return PackedFunc(ret); return PackedFunc(ret);
} }
template<int N, typename F> template <int N, typename F>
inline PackedFunc PackFuncPackedArg_( inline PackedFunc PackFuncPackedArg_(
F f, const std::vector<ArgConvertCode>& codes) { F f, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size()); int num_args = static_cast<int>(codes.size());
...@@ -221,7 +233,7 @@ inline PackedFunc PackFuncPackedArg_( ...@@ -221,7 +233,7 @@ inline PackedFunc PackFuncPackedArg_(
++ptr; ++ptr;
break; break;
} }
case INT64_TO_UINT32 : { case INT64_TO_UINT32: {
*reinterpret_cast<uint32_t*>(ptr) = *reinterpret_cast<uint32_t*>(ptr) =
static_cast<uint32_t>(args.values[i].v_int64); static_cast<uint32_t>(args.values[i].v_int64);
++ptr; ++ptr;
...@@ -234,7 +246,8 @@ inline PackedFunc PackFuncPackedArg_( ...@@ -234,7 +246,8 @@ inline PackedFunc PackFuncPackedArg_(
break; break;
} }
default: { default: {
LOG(FATAL) << "not reached"; break; LOG(FATAL) << "not reached";
break;
} }
} }
} }
...@@ -244,8 +257,9 @@ inline PackedFunc PackFuncPackedArg_( ...@@ -244,8 +257,9 @@ inline PackedFunc PackFuncPackedArg_(
} }
} // namespace detail } // namespace detail
template<typename F> template <typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLDataType>& arg_types) { inline PackedFunc PackFuncVoidAddr(
F f, const std::vector<DGLDataType>& arg_types) {
std::vector<detail::ArgConvertCode> codes(arg_types.size()); std::vector<detail::ArgConvertCode> codes(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
codes[i] = detail::GetArgConvertCode(arg_types[i]); codes[i] = detail::GetArgConvertCode(arg_types[i]);
...@@ -265,7 +279,8 @@ inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types) { ...@@ -265,7 +279,8 @@ inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types) {
size_t base = arg_types.size(); size_t base = arg_types.size();
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i].code != kHandle) { if (arg_types[i].code != kHandle) {
base = i; break; base = i;
break;
} }
} }
for (size_t i = base; i < arg_types.size(); ++i) { for (size_t i = base; i < arg_types.size(); ++i) {
...@@ -275,8 +290,9 @@ inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types) { ...@@ -275,8 +290,9 @@ inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types) {
return base; return base;
} }
template<typename F> template <typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLDataType>& arg_types) { inline PackedFunc PackFuncNonBufferArg(
F f, const std::vector<DGLDataType>& arg_types) {
size_t num_buffer = NumBufferArgs(arg_types); size_t num_buffer = NumBufferArgs(arg_types);
std::vector<detail::ArgConvertCode> codes; std::vector<detail::ArgConvertCode> codes;
for (size_t i = num_buffer; i < arg_types.size(); ++i) { for (size_t i = num_buffer; i < arg_types.size(); ++i) {
...@@ -292,8 +308,9 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLDataType>& arg_ ...@@ -292,8 +308,9 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLDataType>& arg_
} }
} }
template<typename F> template <typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DGLDataType>& arg_types) { inline PackedFunc PackFuncPackedArg(
F f, const std::vector<DGLDataType>& arg_types) {
std::vector<detail::ArgConvertCode> codes; std::vector<detail::ArgConvertCode> codes;
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
codes.push_back(detail::GetArgConvertCode(arg_types[i])); codes.push_back(detail::GetArgConvertCode(arg_types[i]));
......
...@@ -8,4 +8,4 @@ namespace dgl { ...@@ -8,4 +8,4 @@ namespace dgl {
namespace runtime { namespace runtime {
DefaultGrainSizeT default_grain_size; DefaultGrainSizeT default_grain_size;
} // namespace runtime } // namespace runtime
} // namesoace dgl } // namespace dgl
...@@ -3,13 +3,15 @@ ...@@ -3,13 +3,15 @@
* \file registry.cc * \file registry.cc
* \brief The global registry of packed function. * \brief The global registry of packed function.
*/ */
#include <dgl/runtime/registry.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <unordered_map>
#include <mutex>
#include <memory>
#include <array> #include <array>
#include <memory>
#include <mutex>
#include <unordered_map>
#include "runtime_base.h" #include "runtime_base.h"
namespace dgl { namespace dgl {
...@@ -18,9 +20,10 @@ namespace runtime { ...@@ -18,9 +20,10 @@ namespace runtime {
struct Registry::Manager { struct Registry::Manager {
// map storing the functions. // map storing the functions.
// We delibrately used raw pointer // We delibrately used raw pointer
// This is because PackedFunc can contain callbacks into the host languge(python) // This is because PackedFunc can contain callbacks into the host
// and the resource can become invalid because of indeterminstic order of destruction. // languge(python) and the resource can become invalid because of
// The resources will only be recycled during program exit. // indeterminstic order of destruction. The resources will only be recycled
// during program exit.
std::unordered_map<std::string, Registry*> fmap; std::unordered_map<std::string, Registry*> fmap;
// vtable for extension type // vtable for extension type
std::array<ExtTypeVTable, kExtEnd> ext_vtable; std::array<ExtTypeVTable, kExtEnd> ext_vtable;
...@@ -44,7 +47,8 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*) ...@@ -44,7 +47,8 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*)
return *this; return *this;
} }
Registry& Registry::Register(const std::string& name, bool override) { // NOLINT(*) Registry& Registry::Register(
const std::string& name, bool override) { // NOLINT(*)
Manager* m = Manager::Global(); Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex); std::lock_guard<std::mutex> lock(m->mutex);
auto it = m->fmap.find(name); auto it = m->fmap.find(name);
...@@ -54,8 +58,7 @@ Registry& Registry::Register(const std::string& name, bool override) { // NOLIN ...@@ -54,8 +58,7 @@ Registry& Registry::Register(const std::string& name, bool override) { // NOLIN
m->fmap[name] = r; m->fmap[name] = r;
return *r; return *r;
} else { } else {
CHECK(override) CHECK(override) << "Global PackedFunc " << name << " is already registered";
<< "Global PackedFunc " << name << " is already registered";
return *it->second; return *it->second;
} }
} }
...@@ -82,7 +85,7 @@ std::vector<std::string> Registry::ListNames() { ...@@ -82,7 +85,7 @@ std::vector<std::string> Registry::ListNames() {
std::lock_guard<std::mutex> lock(m->mutex); std::lock_guard<std::mutex> lock(m->mutex);
std::vector<std::string> keys; std::vector<std::string> keys;
keys.reserve(m->fmap.size()); keys.reserve(m->fmap.size());
for (const auto &kv : m->fmap) { for (const auto& kv : m->fmap) {
keys.push_back(kv.first); keys.push_back(kv.first);
} }
return keys; return keys;
...@@ -92,8 +95,7 @@ ExtTypeVTable* ExtTypeVTable::Get(int type_code) { ...@@ -92,8 +95,7 @@ ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
CHECK(type_code > kExtBegin && type_code < kExtEnd); CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global(); Registry::Manager* m = Registry::Manager::Global();
ExtTypeVTable* vt = &(m->ext_vtable[type_code]); ExtTypeVTable* vt = &(m->ext_vtable[type_code]);
CHECK(vt->destroy != nullptr) CHECK(vt->destroy != nullptr) << "Extension type not registered";
<< "Extension type not registered";
return vt; return vt;
} }
...@@ -114,7 +116,7 @@ struct DGLFuncThreadLocalEntry { ...@@ -114,7 +116,7 @@ struct DGLFuncThreadLocalEntry {
/*! \brief result holder for returning strings */ /*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str; std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */ /*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp; std::vector<const char*> ret_vec_charp;
}; };
/*! \brief Thread local store that can be used to hold return values. */ /*! \brief Thread local store that can be used to hold return values. */
...@@ -126,8 +128,7 @@ int DGLExtTypeFree(void* handle, int type_code) { ...@@ -126,8 +128,7 @@ int DGLExtTypeFree(void* handle, int type_code) {
API_END(); API_END();
} }
int DGLFuncRegisterGlobal( int DGLFuncRegisterGlobal(const char* name, DGLFunctionHandle f, int override) {
const char* name, DGLFunctionHandle f, int override) {
API_BEGIN(); API_BEGIN();
dgl::runtime::Registry::Register(name, override != 0) dgl::runtime::Registry::Register(name, override != 0)
.set_body(*static_cast<dgl::runtime::PackedFunc*>(f)); .set_body(*static_cast<dgl::runtime::PackedFunc*>(f));
...@@ -136,8 +137,7 @@ int DGLFuncRegisterGlobal( ...@@ -136,8 +137,7 @@ int DGLFuncRegisterGlobal(
int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) { int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) {
API_BEGIN(); API_BEGIN();
const dgl::runtime::PackedFunc* fp = const dgl::runtime::PackedFunc* fp = dgl::runtime::Registry::Get(name);
dgl::runtime::Registry::Get(name);
if (fp != nullptr) { if (fp != nullptr) {
*out = new dgl::runtime::PackedFunc(*fp); // NOLINT(*) *out = new dgl::runtime::PackedFunc(*fp); // NOLINT(*)
} else { } else {
...@@ -146,10 +146,9 @@ int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) { ...@@ -146,10 +146,9 @@ int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) {
API_END(); API_END();
} }
int DGLFuncListGlobalNames(int *out_size, int DGLFuncListGlobalNames(int* out_size, const char*** out_array) {
const char*** out_array) {
API_BEGIN(); API_BEGIN();
DGLFuncThreadLocalEntry *ret = DGLFuncThreadLocalStore::Get(); DGLFuncThreadLocalEntry* ret = DGLFuncThreadLocalStore::Get();
ret->ret_vec_str = dgl::runtime::Registry::ListNames(); ret->ret_vec_str = dgl::runtime::Registry::ListNames();
ret->ret_vec_charp.clear(); ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
......
...@@ -14,9 +14,10 @@ namespace dgl { ...@@ -14,9 +14,10 @@ namespace dgl {
namespace runtime { namespace runtime {
/* /*
* The runtime allocates resources during the computation. Some of the resources cannot be * The runtime allocates resources during the computation. Some of the resources
* destroyed after the process exits especially when the process doesn't exits normally. * cannot be destroyed after the process exits especially when the process
* We need to keep track of the resources in the system and clean them up properly. * doesn't exits normally. We need to keep track of the resources in the system
* and clean them up properly.
*/ */
class ResourceManager { class ResourceManager {
std::unordered_map<std::string, std::shared_ptr<Resource>> resources; std::unordered_map<std::string, std::shared_ptr<Resource>> resources;
...@@ -25,12 +26,11 @@ class ResourceManager { ...@@ -25,12 +26,11 @@ class ResourceManager {
void Add(const std::string &key, std::shared_ptr<Resource> resource) { void Add(const std::string &key, std::shared_ptr<Resource> resource) {
auto it = resources.find(key); auto it = resources.find(key);
CHECK(it == resources.end()) << key << " already exists"; CHECK(it == resources.end()) << key << " already exists";
resources.insert(std::pair<std::string, std::shared_ptr<Resource>>(key, resource)); resources.insert(
std::pair<std::string, std::shared_ptr<Resource>>(key, resource));
} }
void Erase(const std::string &key) { void Erase(const std::string &key) { resources.erase(key); }
resources.erase(key);
}
void Cleanup() { void Cleanup() {
for (auto it = resources.begin(); it != resources.end(); it++) { for (auto it = resources.begin(); it != resources.end(); it++) {
...@@ -46,13 +46,9 @@ void AddResource(const std::string &key, std::shared_ptr<Resource> resource) { ...@@ -46,13 +46,9 @@ void AddResource(const std::string &key, std::shared_ptr<Resource> resource) {
manager.Add(key, resource); manager.Add(key, resource);
} }
void DeleteResource(const std::string &key) { void DeleteResource(const std::string &key) { manager.Erase(key); }
manager.Erase(key);
}
void CleanupResources() { void CleanupResources() { manager.Cleanup(); }
manager.Cleanup();
}
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
...@@ -6,24 +6,23 @@ ...@@ -6,24 +6,23 @@
#ifndef DGL_RUNTIME_RESOURCE_MANAGER_H_ #ifndef DGL_RUNTIME_RESOURCE_MANAGER_H_
#define DGL_RUNTIME_RESOURCE_MANAGER_H_ #define DGL_RUNTIME_RESOURCE_MANAGER_H_
#include <unordered_map>
#include <string>
#include <memory> #include <memory>
#include <string>
#include <unordered_map>
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
/* /*
* A class that provides the interface to describe a resource that can be managed by * A class that provides the interface to describe a resource that can be
* a resource manager. Some of the resources cannot be free'd automatically when * managed by a resource manager. Some of the resources cannot be free'd
* the process exits, especially when the process doesn't exit normally. One example * automatically when the process exits, especially when the process doesn't
* is shared memory. We can keep track of this kind of resources and manage them * exit normally. One example is shared memory. We can keep track of this kind
* properly. * of resources and manage them properly.
*/ */
class Resource { class Resource {
public: public:
virtual ~Resource() { virtual ~Resource() {}
}
virtual void Destroy() = 0; virtual void Destroy() = 0;
}; };
......
...@@ -7,19 +7,32 @@ ...@@ -7,19 +7,32 @@
#define DGL_RUNTIME_RUNTIME_BASE_H_ #define DGL_RUNTIME_RUNTIME_BASE_H_
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include <stdexcept> #include <stdexcept>
/*! \brief macro to guard beginning and end section of all functions */ /*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try { #define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN(); /*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */ and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(std::runtime_error &_except_) { return DGLAPIHandleException(_except_); } return 0; // NOLINT(*) #define API_END() \
} \
catch (std::runtime_error & _except_) { \
return DGLAPIHandleException(_except_); \
} \
return 0; // NOLINT(*)
/*! /*!
* \brief every function starts with API_BEGIN(); * \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR * and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens. * The finally clause contains procedure to cleanup states when an error
* happens.
*/ */
#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return DGLAPIHandleException(_except_); } return 0; // NOLINT(*) #define API_END_HANDLE_ERROR(Finalize) \
} \
catch (std::runtime_error & _except_) { \
Finalize; \
return DGLAPIHandleException(_except_); \
} \
return 0; // NOLINT(*)
/*! /*!
* \brief handle exception throwed out * \brief handle exception throwed out
......
...@@ -25,9 +25,7 @@ Semaphore::Semaphore() { ...@@ -25,9 +25,7 @@ Semaphore::Semaphore() {
} }
} }
void Semaphore::Wait() { void Semaphore::Wait() { WaitForSingleObject(sem_, INFINITE); }
WaitForSingleObject(sem_, INFINITE);
}
bool Semaphore::TimedWait(int) { bool Semaphore::TimedWait(int) {
// Timed wait is not supported on WIN32. // Timed wait is not supported on WIN32.
...@@ -35,19 +33,13 @@ bool Semaphore::TimedWait(int) { ...@@ -35,19 +33,13 @@ bool Semaphore::TimedWait(int) {
return true; return true;
} }
void Semaphore::Post() { void Semaphore::Post() { ReleaseSemaphore(sem_, 1, nullptr); }
ReleaseSemaphore(sem_, 1, nullptr);
}
#else #else
Semaphore::Semaphore() { Semaphore::Semaphore() { sem_init(&sem_, 0, 0); }
sem_init(&sem_, 0, 0);
}
void Semaphore::Wait() { void Semaphore::Wait() { sem_wait(&sem_); }
sem_wait(&sem_);
}
bool Semaphore::TimedWait(int timeout) { bool Semaphore::TimedWait(int timeout) {
// sem_timedwait does not exist in Mac OS. // sem_timedwait does not exist in Mac OS.
...@@ -92,9 +84,7 @@ bool Semaphore::TimedWait(int timeout) { ...@@ -92,9 +84,7 @@ bool Semaphore::TimedWait(int timeout) {
return true; return true;
} }
void Semaphore::Post() { void Semaphore::Post() { sem_post(&sem_); }
sem_post(&sem_);
}
#endif #endif
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include <semaphore.h> #include <semaphore.h>
#endif #endif
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -31,7 +30,8 @@ class Semaphore { ...@@ -31,7 +30,8 @@ class Semaphore {
void Wait(); void Wait();
/*! /*!
* \brief timed wait, decrease semaphore by 1 or returns if times out * \brief timed wait, decrease semaphore by 1 or returns if times out
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely. * \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
*/ */
bool TimedWait(int timeout); bool TimedWait(int timeout);
/*! /*!
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
* \brief Shared memory management. * \brief Shared memory management.
*/ */
#ifndef _WIN32 #ifndef _WIN32
#include <sys/mman.h>
#include <fcntl.h> #include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <dgl/runtime/shared_mem.h>
#include <dmlc/logging.h>
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include <dmlc/logging.h>
#include <dgl/runtime/shared_mem.h>
#include "resource_manager.h" #include "resource_manager.h"
...@@ -22,21 +22,19 @@ namespace runtime { ...@@ -22,21 +22,19 @@ namespace runtime {
* Shared memory is a resource that cannot be cleaned up if the process doesn't * Shared memory is a resource that cannot be cleaned up if the process doesn't
* exit normally. We'll manage the resource with ResourceManager. * exit normally. We'll manage the resource with ResourceManager.
*/ */
class SharedMemoryResource: public Resource { class SharedMemoryResource : public Resource {
std::string name; std::string name;
public: public:
explicit SharedMemoryResource(const std::string &name) { explicit SharedMemoryResource(const std::string &name) { this->name = name; }
this->name = name;
}
void Destroy() { void Destroy() {
// LOG(INFO) << "remove " << name << " for shared memory"; // LOG(INFO) << "remove " << name << " for shared memory";
#ifndef _WIN32 #ifndef _WIN32
shm_unlink(name.c_str()); shm_unlink(name.c_str());
#else // _WIN32 #else // _WIN32
// NOTHING; Windows automatically removes the shared memory object once all handles // NOTHING; Windows automatically removes the shared memory object once all
// are unmapped. // handles are unmapped.
#endif #endif
} }
}; };
...@@ -55,24 +53,21 @@ SharedMemory::SharedMemory(const std::string &name) { ...@@ -55,24 +53,21 @@ SharedMemory::SharedMemory(const std::string &name) {
SharedMemory::~SharedMemory() { SharedMemory::~SharedMemory() {
#ifndef _WIN32 #ifndef _WIN32
if (ptr_ && size_ != 0) if (ptr_ && size_ != 0) CHECK(munmap(ptr_, size_) != -1) << strerror(errno);
CHECK(munmap(ptr_, size_) != -1) << strerror(errno); if (fd_ != -1) close(fd_);
if (fd_ != -1)
close(fd_);
if (own_) { if (own_) {
// LOG(INFO) << "remove " << name << " for shared memory"; // LOG(INFO) << "remove " << name << " for shared memory";
if (name != "") { if (name != "") {
shm_unlink(name.c_str()); shm_unlink(name.c_str());
// The resource has been deleted. We don't need to keep track of it any more. // The resource has been deleted. We don't need to keep track of it any
// more.
DeleteResource(name); DeleteResource(name);
} }
} }
#else #else
if (ptr_) if (ptr_) CHECK(UnmapViewOfFile(ptr_)) << "Win32 Error: " << GetLastError();
CHECK(UnmapViewOfFile(ptr_)) << "Win32 Error: " << GetLastError(); if (handle_) CloseHandle(handle_);
if (handle_) // Windows do not need a separate shm_unlink step.
CloseHandle(handle_);
// Windows do not need a separate shm_unlink step.
#endif // _WIN32 #endif // _WIN32
} }
...@@ -82,28 +77,26 @@ void *SharedMemory::CreateNew(size_t sz) { ...@@ -82,28 +77,26 @@ void *SharedMemory::CreateNew(size_t sz) {
// We need to create a shared-memory file. // We need to create a shared-memory file.
// TODO(zhengda) we need to report error if the shared-memory file exists. // TODO(zhengda) we need to report error if the shared-memory file exists.
int flag = O_RDWR|O_CREAT; int flag = O_RDWR | O_CREAT;
fd_ = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR); fd_ = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR);
CHECK_NE(fd_, -1) << "fail to open " << name << ": " << strerror(errno); CHECK_NE(fd_, -1) << "fail to open " << name << ": " << strerror(errno);
// Shared memory cannot be deleted if the process exits abnormally in Linux. // Shared memory cannot be deleted if the process exits abnormally in Linux.
AddResource(name, std::shared_ptr<Resource>(new SharedMemoryResource(name))); AddResource(name, std::shared_ptr<Resource>(new SharedMemoryResource(name)));
auto res = ftruncate(fd_, sz); auto res = ftruncate(fd_, sz);
CHECK_NE(res, -1) CHECK_NE(res, -1) << "Failed to truncate the file. " << strerror(errno);
<< "Failed to truncate the file. " << strerror(errno); ptr_ = mmap(NULL, sz, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);
ptr_ = mmap(NULL, sz, PROT_READ|PROT_WRITE, MAP_SHARED, fd_, 0);
CHECK_NE(ptr_, MAP_FAILED) CHECK_NE(ptr_, MAP_FAILED)
<< "Failed to map shared memory. mmap failed with error " << strerror(errno); << "Failed to map shared memory. mmap failed with error "
<< strerror(errno);
this->size_ = sz; this->size_ = sz;
return ptr_; return ptr_;
#else #else
handle_ = CreateFileMapping( handle_ = CreateFileMapping(
INVALID_HANDLE_VALUE, INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE,
nullptr, static_cast<DWORD>(sz >> 32), static_cast<DWORD>(sz & 0xFFFFFFFF),
PAGE_READWRITE,
static_cast<DWORD>(sz >> 32),
static_cast<DWORD>(sz & 0xFFFFFFFF),
name.c_str()); name.c_str());
CHECK(handle_ != nullptr) << "fail to open " << name << ", Win32 error: " << GetLastError(); CHECK(handle_ != nullptr)
<< "fail to open " << name << ", Win32 error: " << GetLastError();
ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz); ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz);
if (ptr_ == nullptr) { if (ptr_ == nullptr) {
LOG(FATAL) << "Memory mapping failed, Win32 error: " << GetLastError(); LOG(FATAL) << "Memory mapping failed, Win32 error: " << GetLastError();
...@@ -120,14 +113,16 @@ void *SharedMemory::Open(size_t sz) { ...@@ -120,14 +113,16 @@ void *SharedMemory::Open(size_t sz) {
int flag = O_RDWR; int flag = O_RDWR;
fd_ = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR); fd_ = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR);
CHECK_NE(fd_, -1) << "fail to open " << name << ": " << strerror(errno); CHECK_NE(fd_, -1) << "fail to open " << name << ": " << strerror(errno);
ptr_ = mmap(NULL, sz, PROT_READ|PROT_WRITE, MAP_SHARED, fd_, 0); ptr_ = mmap(NULL, sz, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);
CHECK_NE(ptr_, MAP_FAILED) CHECK_NE(ptr_, MAP_FAILED)
<< "Failed to map shared memory. mmap failed with error " << strerror(errno); << "Failed to map shared memory. mmap failed with error "
<< strerror(errno);
this->size_ = sz; this->size_ = sz;
return ptr_; return ptr_;
#else #else
handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name.c_str()); handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name.c_str());
CHECK(handle_ != nullptr) << "fail to open " << name << ", Win32 Error: " << GetLastError(); CHECK(handle_ != nullptr)
<< "fail to open " << name << ", Win32 Error: " << GetLastError();
ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz); ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz);
if (ptr_ == nullptr) { if (ptr_ == nullptr) {
LOG(FATAL) << "Memory mapping failed, Win32 error: " << GetLastError(); LOG(FATAL) << "Memory mapping failed, Win32 error: " << GetLastError();
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
* \file system_lib_module.cc * \file system_lib_module.cc
* \brief SystemLib module. * \brief SystemLib module.
*/ */
#include <dgl/runtime/registry.h>
#include <dgl/runtime/c_backend_api.h> #include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/registry.h>
#include <mutex> #include <mutex>
#include "module_util.h" #include "module_util.h"
namespace dgl { namespace dgl {
...@@ -15,9 +17,7 @@ class SystemLibModuleNode : public ModuleNode { ...@@ -15,9 +17,7 @@ class SystemLibModuleNode : public ModuleNode {
public: public:
SystemLibModuleNode() = default; SystemLibModuleNode() = default;
const char* type_key() const final { const char* type_key() const final { return "system_lib"; }
return "system_lib";
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
...@@ -57,8 +57,8 @@ class SystemLibModuleNode : public ModuleNode { ...@@ -57,8 +57,8 @@ class SystemLibModuleNode : public ModuleNode {
auto it = tbl_.find(name); auto it = tbl_.find(name);
if (it != tbl_.end() && ptr != it->second) { if (it != tbl_.end() && ptr != it->second) {
LOG(WARNING) << "SystemLib symbol " << name LOG(WARNING) << "SystemLib symbol " << name
<< " get overriden to a different address " << " get overriden to a different address " << ptr << "->"
<< ptr << "->" << it->second; << it->second;
} }
tbl_[name] = ptr; tbl_[name] = ptr;
} }
...@@ -80,9 +80,9 @@ class SystemLibModuleNode : public ModuleNode { ...@@ -80,9 +80,9 @@ class SystemLibModuleNode : public ModuleNode {
}; };
DGL_REGISTER_GLOBAL("module._GetSystemLib") DGL_REGISTER_GLOBAL("module._GetSystemLib")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = runtime::Module(SystemLibModuleNode::Global()); *rv = runtime::Module(SystemLibModuleNode::Global());
}); });
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
......
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
* \brief Adapter library caller * \brief Adapter library caller
*/ */
#include <dgl/runtime/tensordispatch.h>
#include <dgl/runtime/registry.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h>
#if defined(WIN32) || defined(_WIN32) #if defined(WIN32) || defined(_WIN32)
#include <windows.h> #include <windows.h>
#else // !WIN32 #else // !WIN32
#include <dlfcn.h> #include <dlfcn.h>
#endif // WIN32 #endif // WIN32
#include <cstring> #include <cstring>
...@@ -23,25 +23,27 @@ bool TensorDispatcher::Load(const char *path) { ...@@ -23,25 +23,27 @@ bool TensorDispatcher::Load(const char *path) {
CHECK(!available_) << "The tensor adapter can only load once."; CHECK(!available_) << "The tensor adapter can only load once.";
if (path == nullptr || strlen(path) == 0) if (path == nullptr || strlen(path) == 0)
// does not have dispatcher library; all operators fall back to DGL's implementation // does not have dispatcher library; all operators fall back to DGL's
// implementation
return false; return false;
#if defined(WIN32) || defined(_WIN32) #if defined(WIN32) || defined(_WIN32)
handle_ = LoadLibrary(path); handle_ = LoadLibrary(path);
if (!handle_) if (!handle_) return false;
return false;
for (int i = 0; i < num_entries_; ++i) { for (int i = 0; i < num_entries_; ++i) {
entrypoints_[i] = reinterpret_cast<void*>(GetProcAddress(handle_, names_[i])); entrypoints_[i] =
reinterpret_cast<void *>(GetProcAddress(handle_, names_[i]));
CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i]; CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i];
} }
#else // !WIN32 #else // !WIN32
handle_ = dlopen(path, RTLD_LAZY); handle_ = dlopen(path, RTLD_LAZY);
if (!handle_) { if (!handle_) {
DLOG(WARNING) << "Could not open file: " << dlerror() DLOG(WARNING)
<< ". This does not affect DGL's but might impact its performance."; << "Could not open file: " << dlerror()
<< ". This does not affect DGL's but might impact its performance.";
return false; return false;
} }
......
...@@ -3,23 +3,24 @@ ...@@ -3,23 +3,24 @@
* \file thread_pool.cc * \file thread_pool.cc
* \brief Threadpool for multi-threading runtime. * \brief Threadpool for multi-threading runtime.
*/ */
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h> #include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/registry.h> #include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/threading_backend.h> #include <dgl/runtime/threading_backend.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <thread> #include <dmlc/thread_local.h>
#include <condition_variable>
#include <mutex>
#include <atomic>
#include <algorithm> #include <algorithm>
#include <vector> #include <atomic>
#include <string> #include <condition_variable>
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <mutex>
#include <sstream> #include <sstream>
#include <string>
#include <thread>
#include <vector>
const constexpr int kL1CacheBytes = 64; const constexpr int kL1CacheBytes = 64;
...@@ -35,10 +36,8 @@ constexpr int kSyncStride = 64 / sizeof(std::atomic<int>); ...@@ -35,10 +36,8 @@ constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);
class ParallelLauncher { class ParallelLauncher {
public: public:
// Reset the the task request. // Reset the the task request.
void Init(FDGLParallelLambda flambda, void Init(
void* cdata, FDGLParallelLambda flambda, void* cdata, int num_task, bool need_sync) {
int num_task,
bool need_sync) {
num_pending_.store(num_task); num_pending_.store(num_task);
this->cdata = cdata; this->cdata = cdata;
this->flambda = flambda; this->flambda = flambda;
...@@ -54,17 +53,14 @@ class ParallelLauncher { ...@@ -54,17 +53,14 @@ class ParallelLauncher {
} }
if (need_sync) { if (need_sync) {
for (int i = 0; i < num_task; ++i) { for (int i = 0; i < num_task; ++i) {
sync_counter_[i * kSyncStride].store( sync_counter_[i * kSyncStride].store(0, std::memory_order_relaxed);
0, std::memory_order_relaxed);
} }
this->env.sync_handle = sync_counter_; this->env.sync_handle = sync_counter_;
} else { } else {
this->env.sync_handle = nullptr; this->env.sync_handle = nullptr;
} }
} }
~ParallelLauncher() { ~ParallelLauncher() { delete[] sync_counter_; }
delete[] sync_counter_;
}
// Wait n jobs to finish // Wait n jobs to finish
int WaitForJobs() { int WaitForJobs() {
while (num_pending_.load() != 0) { while (num_pending_.load() != 0) {
...@@ -90,9 +86,7 @@ class ParallelLauncher { ...@@ -90,9 +86,7 @@ class ParallelLauncher {
has_error_.store(true); has_error_.store(true);
} }
// Signal that one job has finished. // Signal that one job has finished.
void SignalJobFinish() { void SignalJobFinish() { num_pending_.fetch_sub(1); }
num_pending_.fetch_sub(1);
}
// Get thread local version of the store. // Get thread local version of the store.
static ParallelLauncher* ThreadLocal() { static ParallelLauncher* ThreadLocal() {
return dmlc::ThreadLocalStore<ParallelLauncher>::Get(); return dmlc::ThreadLocalStore<ParallelLauncher>::Get();
...@@ -127,15 +121,9 @@ class SpscTaskQueue { ...@@ -127,15 +121,9 @@ class SpscTaskQueue {
int32_t task_id; int32_t task_id;
}; };
SpscTaskQueue() : SpscTaskQueue() : buffer_(new Task[kRingSize]), head_(0), tail_(0) {}
buffer_(new Task[kRingSize]),
head_(0),
tail_(0) {
}
~SpscTaskQueue() { ~SpscTaskQueue() { delete[] buffer_; }
delete[] buffer_;
}
/*! /*!
* \brief Push a task into the queue and notify the comsumer if it is on wait. * \brief Push a task into the queue and notify the comsumer if it is on wait.
...@@ -159,16 +147,16 @@ class SpscTaskQueue { ...@@ -159,16 +147,16 @@ class SpscTaskQueue {
*/ */
bool Pop(Task* output, uint32_t spin_count = 300000) { bool Pop(Task* output, uint32_t spin_count = 300000) {
// Busy wait a bit when the queue is empty. // Busy wait a bit when the queue is empty.
// If a new task comes to the queue quickly, this wait avoid the worker from sleeping. // If a new task comes to the queue quickly, this wait avoid the worker from
// The default spin count is set by following the typical omp convention // sleeping. The default spin count is set by following the typical omp
// convention
for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) { for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
dgl::runtime::threading::YieldThread(); dgl::runtime::threading::YieldThread();
} }
if (pending_.fetch_sub(1) == 0) { if (pending_.fetch_sub(1) == 0) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { cv_.wait(
return pending_.load() >= 0 || exit_now_.load(); lock, [this] { return pending_.load() >= 0 || exit_now_.load(); });
});
} }
if (exit_now_.load(std::memory_order_relaxed)) { if (exit_now_.load(std::memory_order_relaxed)) {
return false; return false;
...@@ -209,7 +197,8 @@ class SpscTaskQueue { ...@@ -209,7 +197,8 @@ class SpscTaskQueue {
return false; return false;
} }
// the cache line paddings are used for avoid false sharing between atomic variables // the cache line paddings are used for avoid false sharing between atomic
// variables
typedef char cache_line_pad_t[kL1CacheBytes]; typedef char cache_line_pad_t[kL1CacheBytes];
cache_line_pad_t pad0_; cache_line_pad_t pad0_;
// size of the queue, the queue can host size_ - 1 items at most // size of the queue, the queue can host size_ - 1 items at most
...@@ -243,16 +232,17 @@ class SpscTaskQueue { ...@@ -243,16 +232,17 @@ class SpscTaskQueue {
// The thread pool // The thread pool
class ThreadPool { class ThreadPool {
public: public:
ThreadPool(): num_workers_(dgl::runtime::threading::MaxConcurrency()) { ThreadPool() : num_workers_(dgl::runtime::threading::MaxConcurrency()) {
for (int i = 0; i < num_workers_; ++i) { for (int i = 0; i < num_workers_; ++i) {
// The SpscTaskQueue only hosts ONE item at a time // The SpscTaskQueue only hosts ONE item at a time
queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue())); queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
} }
threads_ = std::unique_ptr<dgl::runtime::threading::ThreadGroup>( threads_ = std::unique_ptr<dgl::runtime::threading::ThreadGroup>(
new dgl::runtime::threading::ThreadGroup( new dgl::runtime::threading::ThreadGroup(
num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
exclude_worker0_ /* include_main_thread */)); exclude_worker0_ /* include_main_thread */));
num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_); num_workers_used_ =
threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);
} }
~ThreadPool() { ~ThreadPool() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) { for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
...@@ -260,13 +250,11 @@ class ThreadPool { ...@@ -260,13 +250,11 @@ class ThreadPool {
} }
threads_.reset(); threads_.reset();
} }
int Launch(FDGLParallelLambda flambda, int Launch(
void* cdata, FDGLParallelLambda flambda, void* cdata, int num_task, int need_sync) {
int num_task,
int need_sync) {
ParallelLauncher* launcher = ParallelLauncher::ThreadLocal(); ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
CHECK(!launcher->is_worker) CHECK(!launcher->is_worker) << "Cannot launch parallel job inside worker, "
<< "Cannot launch parallel job inside worker, consider fuse then parallel"; "consider fuse then parallel";
if (num_task == 0) { if (num_task == 0) {
num_task = num_workers_used_; num_task = num_workers_used_;
} }
...@@ -300,11 +288,11 @@ class ThreadPool { ...@@ -300,11 +288,11 @@ class ThreadPool {
return dmlc::ThreadLocalStore<ThreadPool>::Get(); return dmlc::ThreadLocalStore<ThreadPool>::Get();
} }
void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) { void UpdateWorkerConfiguration(
threading::ThreadGroup::AffinityMode mode, int nthreads) {
// this will also reset the affinity of the ThreadGroup // this will also reset the affinity of the ThreadGroup
// may use less than the MaxConcurrency number of workers // may use less than the MaxConcurrency number of workers
num_workers_used_ = threads_->Configure(mode, nthreads, num_workers_used_ = threads_->Configure(mode, nthreads, exclude_worker0_);
exclude_worker0_);
// if MaxConcurrency restricted the number of workers (e.g., due to // if MaxConcurrency restricted the number of workers (e.g., due to
// hyperthreading), respect the restriction // hyperthreading), respect the restriction
num_workers_used_ = std::min(num_workers_, num_workers_used_); num_workers_used_ = std::min(num_workers_, num_workers_used_);
...@@ -341,23 +329,19 @@ class ThreadPool { ...@@ -341,23 +329,19 @@ class ThreadPool {
}; };
DGL_REGISTER_GLOBAL("runtime.config_threadpool") DGL_REGISTER_GLOBAL("runtime.config_threadpool")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
threading::ThreadGroup::AffinityMode mode =\ threading::ThreadGroup::AffinityMode mode =
static_cast<threading::ThreadGroup::AffinityMode>(\ static_cast<threading::ThreadGroup::AffinityMode>(
static_cast<int>(args[0])); static_cast<int>(args[0]));
int nthreads = args[1]; int nthreads = args[1];
ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads); ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
}); });
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
int DGLBackendParallelLaunch( int DGLBackendParallelLaunch(
FDGLParallelLambda flambda, FDGLParallelLambda flambda, void* cdata, int num_task) {
void* cdata,
int num_task) {
int res = dgl::runtime::ThreadPool::ThreadLocal()->Launch( int res = dgl::runtime::ThreadPool::ThreadLocal()->Launch(
flambda, cdata, num_task, 1); flambda, cdata, num_task, 1);
return res; return res;
...@@ -372,8 +356,8 @@ int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv) { ...@@ -372,8 +356,8 @@ int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv) {
1, std::memory_order_release); 1, std::memory_order_release);
for (int i = 0; i < num_task; ++i) { for (int i = 0; i < num_task; ++i) {
if (i != task_id) { if (i != task_id) {
while (sync_counter[i * kSyncStride].load( while (sync_counter[i * kSyncStride].load(std::memory_order_relaxed) <=
std::memory_order_relaxed) <= old_counter) { old_counter) {
dgl::runtime::threading::YieldThread(); dgl::runtime::threading::YieldThread();
} }
} }
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -40,9 +41,12 @@ enum class StorageRank { ...@@ -40,9 +41,12 @@ enum class StorageRank {
*/ */
inline StorageRank DefaultStorageRank(int thread_scope_rank) { inline StorageRank DefaultStorageRank(int thread_scope_rank) {
switch (thread_scope_rank) { switch (thread_scope_rank) {
case -1: return StorageRank::kGlobal; case -1:
case 0: return StorageRank::kShared; return StorageRank::kGlobal;
case 1: return StorageRank::kLocal; case 0:
return StorageRank::kShared;
case 1:
return StorageRank::kLocal;
default: { default: {
LOG(FATAL) << "unknown rank"; LOG(FATAL) << "unknown rank";
return StorageRank::kGlobal; return StorageRank::kGlobal;
...@@ -66,11 +70,17 @@ struct StorageScope { ...@@ -66,11 +70,17 @@ struct StorageScope {
inline std::string to_string() const { inline std::string to_string() const {
std::string ret; std::string ret;
switch (rank) { switch (rank) {
case StorageRank::kGlobal: return "global" + tag; case StorageRank::kGlobal:
case StorageRank::kShared: return "shared" + tag; return "global" + tag;
case StorageRank::kWarp: return "warp" + tag; case StorageRank::kShared:
case StorageRank::kLocal: return "local" + tag; return "shared" + tag;
default: LOG(FATAL) << "unknown storage scope"; return ""; case StorageRank::kWarp:
return "warp" + tag;
case StorageRank::kLocal:
return "local" + tag;
default:
LOG(FATAL) << "unknown storage scope";
return "";
} }
} }
/*! /*!
...@@ -80,7 +90,7 @@ struct StorageScope { ...@@ -80,7 +90,7 @@ struct StorageScope {
*/ */
static StorageScope make(const std::string& s) { static StorageScope make(const std::string& s) {
StorageScope r; StorageScope r;
if (s.compare(0, 6, "global") == 0) { if (s.compare(0, 6, "global") == 0) {
r.rank = StorageRank::kGlobal; r.rank = StorageRank::kGlobal;
r.tag = s.substr(6, std::string::npos); r.tag = s.substr(6, std::string::npos);
} else if (s.compare(0, 6, "shared") == 0) { } else if (s.compare(0, 6, "shared") == 0) {
...@@ -129,7 +139,6 @@ struct ThreadScope { ...@@ -129,7 +139,6 @@ struct ThreadScope {
} }
}; };
/*! \brief workload speccification */ /*! \brief workload speccification */
struct ThreadWorkLoad { struct ThreadWorkLoad {
// array, first three are thread configuration. // array, first three are thread configuration.
...@@ -138,22 +147,17 @@ struct ThreadWorkLoad { ...@@ -138,22 +147,17 @@ struct ThreadWorkLoad {
* \param i The block dimension. * \param i The block dimension.
* \return i-th block dim * \return i-th block dim
*/ */
inline size_t block_dim(size_t i) const { inline size_t block_dim(size_t i) const { return work_size[i + 3]; }
return work_size[i + 3];
}
/*! /*!
* \param i The grid dimension. * \param i The grid dimension.
* \return i-th grid dim * \return i-th grid dim
*/ */
inline size_t grid_dim(size_t i) const { inline size_t grid_dim(size_t i) const { return work_size[i]; }
return work_size[i];
}
}; };
/*! \brief Thread axis configuration */ /*! \brief Thread axis configuration */
class ThreadAxisConfig { class ThreadAxisConfig {
public: public:
void Init(size_t base, void Init(size_t base, const std::vector<std::string>& thread_axis_tags) {
const std::vector<std::string>& thread_axis_tags) {
base_ = base; base_ = base;
std::vector<bool> filled(6, false); std::vector<bool> filled(6, false);
for (size_t i = 0; i < thread_axis_tags.size(); ++i) { for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
...@@ -180,9 +184,7 @@ class ThreadAxisConfig { ...@@ -180,9 +184,7 @@ class ThreadAxisConfig {
return w; return w;
} }
// return the work dim // return the work dim
size_t work_dim() const { size_t work_dim() const { return work_dim_; }
return work_dim_;
}
private: private:
/*! \brief base axis */ /*! \brief base axis */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment