Unverified Commit 401e1278 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4811)



* [Misc] clang-format auto fix.

* fix

* manual
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 6c53f351
......@@ -2,14 +2,15 @@
* Copyright (c) 2017 by Contributors
* \file file_util.cc
*/
#include "file_util.h"
#include <dgl/runtime/serializer.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dgl/runtime/serializer.h>
#include <fstream>
#include <vector>
#include <unordered_map>
#include "file_util.h"
#include <vector>
namespace dgl {
namespace runtime {
......@@ -52,8 +53,8 @@ bool FunctionInfo::Load(dmlc::Stream* reader) {
return true;
}
std::string GetFileFormat(const std::string& file_name,
const std::string& format) {
std::string GetFileFormat(
const std::string& file_name, const std::string& format) {
std::string fmt = format;
if (fmt.length() == 0) {
if (file_name.find(".signed.so") != std::string::npos) return "sgx";
......@@ -87,7 +88,7 @@ std::string GetFileBasename(const std::string& file_name) {
}
std::string GetMetaFilePath(const std::string& file_name) {
size_t pos = file_name.find_last_of(".");
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(0, pos) + ".dgl_meta.json";
} else {
......@@ -95,8 +96,7 @@ std::string GetMetaFilePath(const std::string& file_name) {
}
}
void LoadBinaryFromFile(const std::string& file_name,
std::string* data) {
void LoadBinaryFromFile(const std::string& file_name, std::string* data) {
std::ifstream fs(file_name, std::ios::in | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name;
// get its size:
......@@ -107,9 +107,7 @@ void LoadBinaryFromFile(const std::string& file_name,
fs.read(&(*data)[0], size);
}
void SaveBinaryToFile(
const std::string& file_name,
const std::string& data) {
void SaveBinaryToFile(const std::string& file_name, const std::string& data) {
std::ofstream fs(file_name, std::ios::out | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name;
fs.write(&data[0], data.length());
......
......@@ -8,6 +8,7 @@
#include <string>
#include <unordered_map>
#include "meta_data.h"
namespace dgl {
......@@ -17,8 +18,8 @@ namespace runtime {
* \param file_name The name of the file.
* \param format The format of the file.
*/
std::string GetFileFormat(const std::string& file_name,
const std::string& format);
std::string GetFileFormat(
const std::string& file_name, const std::string& format);
/*!
* \return the directory in which DGL stores cached files.
......@@ -44,16 +45,14 @@ std::string GetFileBasename(const std::string& file_name);
* \param file_name The name of the file.
* \param data The data to be loaded.
*/
void LoadBinaryFromFile(const std::string& file_name,
std::string* data);
void LoadBinaryFromFile(const std::string& file_name, std::string* data);
/*!
* \brief Load binary file into a in-memory buffer.
* \param file_name The name of the file.
* \param data The binary data to be saved.
*/
void SaveBinaryToFile(const std::string& file_name,
const std::string& data);
void SaveBinaryToFile(const std::string& file_name, const std::string& data);
/*!
* \brief Save meta data to file.
......
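For orientation, a minimal usage sketch of the two binary-file helpers declared above (not part of this commit; the paths are placeholders):

#include <string>
#include "file_util.h"  // assumes compilation inside src/runtime/

void RoundTripExample() {
  std::string blob;
  // CHECK-fails if the file cannot be opened.
  dgl::runtime::LoadBinaryFromFile("/tmp/module.bin", &blob);
  // Writes the same bytes back out to a new file.
  dgl::runtime::SaveBinaryToFile("/tmp/module_copy.bin", blob);
}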
......@@ -6,11 +6,13 @@
#ifndef DGL_RUNTIME_META_DATA_H_
#define DGL_RUNTIME_META_DATA_H_
#include <dmlc/json.h>
#include <dmlc/io.h>
#include <dgl/runtime/packed_func.h>
#include <dmlc/io.h>
#include <dmlc/json.h>
#include <string>
#include <vector>
#include "runtime_base.h"
namespace dgl {
......
......@@ -4,10 +4,11 @@
* \brief DGL module system
*/
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <unordered_set>
#include <dgl/runtime/registry.h>
#include <cstring>
#include <unordered_set>
#ifndef _LIBCPP_SGX_CONFIG
#include "file_util.h"
#endif
......@@ -44,20 +45,18 @@ void Module::Import(Module other) {
node_->imports_.emplace_back(std::move(other));
}
Module Module::LoadFromFile(const std::string& file_name,
const std::string& format) {
Module Module::LoadFromFile(
const std::string& file_name, const std::string& format) {
#ifndef _LIBCPP_SGX_CONFIG
std::string fmt = GetFileFormat(file_name, format);
CHECK(fmt.length() != 0)
<< "Cannot deduce format of file " << file_name;
CHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name;
if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
fmt = "so";
}
std::string load_f_name = "module.loadfile_" + fmt;
const PackedFunc* f = Registry::Get(load_f_name);
CHECK(f != nullptr)
<< "Loader of " << format << "("
<< load_f_name << ") is not presented.";
CHECK(f != nullptr) << "Loader of " << format << "(" << load_f_name
<< ") is not presented.";
Module m = (*f)(file_name, format);
return m;
#else
......@@ -65,8 +64,8 @@ Module Module::LoadFromFile(const std::string& file_name,
#endif
}
void ModuleNode::SaveToFile(const std::string& file_name,
const std::string& format) {
void ModuleNode::SaveToFile(
const std::string& file_name, const std::string& format) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
}
......@@ -89,9 +88,8 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
}
if (pf == nullptr) {
const PackedFunc* f = Registry::Get(name);
CHECK(f != nullptr)
<< "Cannot find function " << name
<< " in the imported modules or global registry";
CHECK(f != nullptr) << "Cannot find function " << name
<< " in the imported modules or global registry";
return f;
} else {
std::unique_ptr<PackedFunc> f(new PackedFunc(pf));
......@@ -125,7 +123,8 @@ bool RuntimeEnabled(const std::string& target) {
} else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
f_name = "device_api.rocm";
} else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
const PackedFunc* pf =
runtime::Registry::Get("codegen.llvm_target_enabled");
if (pf == nullptr) return false;
return (*pf)(target);
} else {
......@@ -135,41 +134,38 @@ bool RuntimeEnabled(const std::string& target) {
}
DGL_REGISTER_GLOBAL("module._Enabled")
.set_body([](DGLArgs args, DGLRetValue *ret) {
*ret = RuntimeEnabled(args[0]);
.set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = RuntimeEnabled(args[0]);
});
DGL_REGISTER_GLOBAL("module._GetSource")
.set_body([](DGLArgs args, DGLRetValue *ret) {
*ret = args[0].operator Module()->GetSource(args[1]);
.set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = args[0].operator Module()->GetSource(args[1]);
});
DGL_REGISTER_GLOBAL("module._ImportsSize")
.set_body([](DGLArgs args, DGLRetValue *ret) {
*ret = static_cast<int64_t>(
args[0].operator Module()->imports().size());
.set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = static_cast<int64_t>(args[0].operator Module()->imports().size());
});
DGL_REGISTER_GLOBAL("module._GetImport")
.set_body([](DGLArgs args, DGLRetValue *ret) {
*ret = args[0].operator Module()->
imports().at(args[1].operator int());
.set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = args[0].operator Module()->imports().at(args[1].operator int());
});
DGL_REGISTER_GLOBAL("module._GetTypeKey")
.set_body([](DGLArgs args, DGLRetValue *ret) {
*ret = std::string(args[0].operator Module()->type_key());
.set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = std::string(args[0].operator Module()->type_key());
});
DGL_REGISTER_GLOBAL("module._LoadFromFile")
.set_body([](DGLArgs args, DGLRetValue *ret) {
*ret = Module::LoadFromFile(args[0], args[1]);
.set_body([](DGLArgs args, DGLRetValue* ret) {
*ret = Module::LoadFromFile(args[0], args[1]);
});
DGL_REGISTER_GLOBAL("module._SaveToFile")
.set_body([](DGLArgs args, DGLRetValue *ret) {
args[0].operator Module()->
SaveToFile(args[1], args[2]);
.set_body([](DGLArgs args, DGLRetValue* ret) {
args[0].operator Module()->SaveToFile(args[1], args[2]);
});
} // namespace runtime
} // namespace dgl
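As context for the LoadFromFile change above, a minimal sketch of loading a compiled module (not part of this commit; the library path is a placeholder, and passing an empty format string to request extension-based deduction follows GetFileFormat as shown in the file_util.cc hunk):

#include <dgl/runtime/module.h>

void LoadExample() {
  // An empty format lets GetFileFormat deduce "so" from the .so extension.
  dgl::runtime::Module m =
      dgl::runtime::Module::LoadFromFile("/tmp/libkernels.so", "");
  (void)m;  // hold the module handle
}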
......@@ -8,8 +8,10 @@
#endif
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <string>
#include <memory>
#include <string>
#include "module_util.h"
namespace dgl {
......@@ -21,7 +23,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
uint64_t nbytes = 0;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
uint64_t c = mblob[i];
nbytes |= (c & 0xffUL) << (i * 8);
nbytes |= (c & 0xffUL) << (i * 8);
}
dmlc::MemoryFixedSizeStream fs(
const_cast<char*>(mblob + sizeof(nbytes)), static_cast<size_t>(nbytes));
......@@ -33,9 +35,8 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
CHECK(stream->Read(&tkey));
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
<< "Loader of " << tkey << "("
<< fkey << ") is not presented.";
CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey
<< ") is not presented.";
Module m = (*f)(static_cast<void*>(stream));
mlist->push_back(m);
}
......@@ -44,15 +45,14 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
#endif
}
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
PackedFunc WrapPackedFunc(
BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](DGLArgs args, DGLRetValue* rv) {
int ret = (*faddr)(
const_cast<DGLValue*>(args.values),
const_cast<int*>(args.type_codes),
args.num_args);
CHECK_EQ(ret, 0) << DGLGetLastError();
});
int ret = (*faddr)(
const_cast<DGLValue*>(args.values), const_cast<int*>(args.type_codes),
args.num_args);
CHECK_EQ(ret, 0) << DGLGetLastError();
});
}
} // namespace runtime
......
......@@ -6,17 +6,16 @@
#ifndef DGL_RUNTIME_MODULE_UTIL_H_
#define DGL_RUNTIME_MODULE_UTIL_H_
#include <dgl/runtime/module.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <vector>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/module.h>
#include <memory>
#include <vector>
extern "C" {
// Function signature for generated packed function in shared library
typedef int (*BackendPackedCFunc)(void* args,
int* type_codes,
int num_args);
typedef int (*BackendPackedCFunc)(void* args, int* type_codes, int num_args);
} // extern "C"
namespace dgl {
......@@ -26,7 +25,8 @@ namespace runtime {
* \param faddr The function address
* \param mptr The module pointer node.
*/
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr);
PackedFunc WrapPackedFunc(
BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr);
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
......@@ -39,13 +39,13 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
* \param flookup A symbol lookup function.
* \tparam FLookup a function of signature string->void*
*/
template<typename FLookup>
template <typename FLookup>
void InitContextFunctions(FLookup flookup) {
#define DGL_INIT_CONTEXT_FUNC(FuncName) \
if (auto *fp = reinterpret_cast<decltype(&FuncName)*> \
(flookup("__" #FuncName))) { \
*fp = FuncName; \
}
#define DGL_INIT_CONTEXT_FUNC(FuncName) \
if (auto* fp = \
reinterpret_cast<decltype(&FuncName)*>(flookup("__" #FuncName))) { \
*fp = FuncName; \
}
// Initialize the functions
DGL_INIT_CONTEXT_FUNC(DGLFuncCall);
DGL_INIT_CONTEXT_FUNC(DGLAPISetLastError);
......@@ -55,8 +55,8 @@ void InitContextFunctions(FLookup flookup) {
DGL_INIT_CONTEXT_FUNC(DGLBackendParallelLaunch);
DGL_INIT_CONTEXT_FUNC(DGLBackendParallelBarrier);
#undef DGL_INIT_CONTEXT_FUNC
#undef DGL_INIT_CONTEXT_FUNC
}
} // namespace runtime
} // namespace dgl
#endif // DGL_RUNTIME_MODULE_UTIL_H_
#endif // DGL_RUNTIME_MODULE_UTIL_H_
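To make the reformatted DGL_INIT_CONTEXT_FUNC macro above easier to follow, here is a minimal sketch of calling InitContextFunctions with a symbol-lookup callback (not part of this commit; the dlsym-based lookup over a dlopen()-ed handle is an assumption for illustration only):

#include <dlfcn.h>
#include "module_util.h"  // assumes compilation inside src/runtime/

void InitFromHandle(void* lib_handle) {
  // For each context function, the macro looks up "__<name>" in the library
  // and, if the symbol is present, patches it to DGL's own implementation.
  dgl::runtime::InitContextFunctions(
      [lib_handle](const char* name) { return dlsym(lib_handle, name); });
}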
......@@ -4,8 +4,9 @@
* \brief Implementation of runtime object APIs.
*/
#include <dgl/runtime/object.h>
#include <memory>
#include <atomic>
#include <memory>
#include <mutex>
#include <unordered_map>
......@@ -36,7 +37,7 @@ bool Object::_DerivedFrom(uint32_t tid) const {
// this is slow, usually caller always hold the result in a static variable.
uint32_t Object::TypeKey2Index(const char* key) {
TypeManager *t = TypeManager::Global();
TypeManager* t = TypeManager::Global();
std::lock_guard<std::mutex> lock(t->mutex);
std::string skey = key;
auto it = t->key2index.find(skey);
......@@ -50,7 +51,7 @@ uint32_t Object::TypeKey2Index(const char* key) {
}
const char* Object::TypeIndex2Key(uint32_t index) {
TypeManager *t = TypeManager::Global();
TypeManager* t = TypeManager::Global();
std::lock_guard<std::mutex> lock(t->mutex);
CHECK_NE(index, 0);
return t->index2key.at(index - 1).c_str();
......
/*!
* Copyright (c) 2017 by Contributors
* \file pack_args.h
* \brief Utility to pack DGLArgs to other type-erased fution calling convention.
* \brief Utility to pack DGLArgs to other type-erased fution calling
* convention.
*
* Two type erased function signatures are supported.
* - cuda_style(void** args, int num_args);
......@@ -15,8 +16,9 @@
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/packed_func.h>
#include <vector>
#include <cstring>
#include <vector>
namespace dgl {
namespace runtime {
......@@ -38,10 +40,12 @@ union ArgUnion {
*
* \return The wrapped packed function.
*/
template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLDataType>& arg_types);
template <typename F>
inline PackedFunc PackFuncVoidAddr(
F f, const std::vector<DGLDataType>& arg_types);
/*!
* \brief Create a packed function that from function only packs buffer arguments.
* \brief Create a packed function that from function only packs buffer
* arguments.
*
* \param f with signiture (DGLArgs args, DGLRetValue* rv, ArgUnion* pack_args)
* \param arg_types The arguments type information.
......@@ -49,19 +53,23 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLDataType>& arg_type
*
* \return The wrapped packed function.
*/
template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLDataType>& arg_types);
template <typename F>
inline PackedFunc PackFuncNonBufferArg(
F f, const std::vector<DGLDataType>& arg_types);
/*!
* \brief Create a packed function that from function that takes a packed arguments.
* \brief Create a packed function that from function that takes a packed
* arguments.
*
* \param f with signature (DGLArgs args, DGLRetValue* rv, void* pack_args, size_t nbytes)
* \param f with signature (DGLArgs args, DGLRetValue* rv, void* pack_args,
* size_t nbytes)
* \param arg_types The arguments that wish to get from
* \tparam F the function type
*
* \return The wrapped packed function.
*/
template<typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DGLDataType>& arg_types);
template <typename F>
inline PackedFunc PackFuncPackedArg(
F f, const std::vector<DGLDataType>& arg_types);
/*!
* \brief Extract number of buffer argument from the argument types.
* \param arg_types The argument types.
......@@ -71,23 +79,21 @@ inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types);
// implementations details
namespace detail {
template<typename T, int kSize>
template <typename T, int kSize>
class TempArray {
public:
explicit TempArray(int size) {}
T* data() {
return data_;
}
T* data() { return data_; }
private:
T data_[kSize];
};
template<typename T>
template <typename T>
class TempArray<T, 0> {
public:
explicit TempArray(int size) : data_(size) {}
T* data() {
return data_.data();
}
T* data() { return data_.data(); }
private:
std::vector<T> data_;
};
......@@ -120,8 +126,9 @@ inline ArgConvertCode GetArgConvertCode(DGLDataType t) {
return HANDLE_TO_HANDLE;
}
template<int N, typename F>
inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) {
template <int N, typename F>
inline PackedFunc PackFuncVoidAddr_(
F f, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, num_args](DGLArgs args, DGLRetValue* ret) {
TempArray<void*, N> addr_(num_args);
......@@ -141,7 +148,7 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& code
addr[i] = &(holder[i]);
break;
}
case INT64_TO_UINT32 : {
case INT64_TO_UINT32: {
holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
addr[i] = &(holder[i]);
break;
......@@ -158,7 +165,7 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& code
return PackedFunc(ret);
}
template<int N, typename F>
template <int N, typename F>
inline PackedFunc PackFuncNonBufferArg_(
F f, int base, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
......@@ -169,22 +176,27 @@ inline PackedFunc PackFuncNonBufferArg_(
switch (codes[i]) {
case INT64_TO_INT64:
case FLOAT64_TO_FLOAT64: {
LOG(FATAL) << "Donot support 64bit argument to device function"; break;
LOG(FATAL) << "Donot support 64bit argument to device function";
break;
}
case INT64_TO_INT32: {
holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
holder[i].v_int32 =
static_cast<int32_t>(args.values[base + i].v_int64);
break;
}
case INT64_TO_UINT32 : {
holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
case INT64_TO_UINT32: {
holder[i].v_uint32 =
static_cast<uint32_t>(args.values[base + i].v_int64);
break;
}
case FLOAT64_TO_FLOAT32: {
holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
holder[i].v_float32 =
static_cast<float>(args.values[base + i].v_float64);
break;
}
case HANDLE_TO_HANDLE: {
LOG(FATAL) << "not reached"; break;
LOG(FATAL) << "not reached";
break;
}
}
}
......@@ -193,7 +205,7 @@ inline PackedFunc PackFuncNonBufferArg_(
return PackedFunc(ret);
}
template<int N, typename F>
template <int N, typename F>
inline PackedFunc PackFuncPackedArg_(
F f, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
......@@ -221,7 +233,7 @@ inline PackedFunc PackFuncPackedArg_(
++ptr;
break;
}
case INT64_TO_UINT32 : {
case INT64_TO_UINT32: {
*reinterpret_cast<uint32_t*>(ptr) =
static_cast<uint32_t>(args.values[i].v_int64);
++ptr;
......@@ -234,7 +246,8 @@ inline PackedFunc PackFuncPackedArg_(
break;
}
default: {
LOG(FATAL) << "not reached"; break;
LOG(FATAL) << "not reached";
break;
}
}
}
......@@ -244,8 +257,9 @@ inline PackedFunc PackFuncPackedArg_(
}
} // namespace detail
template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLDataType>& arg_types) {
template <typename F>
inline PackedFunc PackFuncVoidAddr(
F f, const std::vector<DGLDataType>& arg_types) {
std::vector<detail::ArgConvertCode> codes(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
codes[i] = detail::GetArgConvertCode(arg_types[i]);
......@@ -265,7 +279,8 @@ inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types) {
size_t base = arg_types.size();
for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i].code != kHandle) {
base = i; break;
base = i;
break;
}
}
for (size_t i = base; i < arg_types.size(); ++i) {
......@@ -275,8 +290,9 @@ inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types) {
return base;
}
template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLDataType>& arg_types) {
template <typename F>
inline PackedFunc PackFuncNonBufferArg(
F f, const std::vector<DGLDataType>& arg_types) {
size_t num_buffer = NumBufferArgs(arg_types);
std::vector<detail::ArgConvertCode> codes;
for (size_t i = num_buffer; i < arg_types.size(); ++i) {
......@@ -292,8 +308,9 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLDataType>& arg_
}
}
template<typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DGLDataType>& arg_types) {
template <typename F>
inline PackedFunc PackFuncPackedArg(
F f, const std::vector<DGLDataType>& arg_types) {
std::vector<detail::ArgConvertCode> codes;
for (size_t i = 0; i < arg_types.size(); ++i) {
codes.push_back(detail::GetArgConvertCode(arg_types[i]));
......
......@@ -8,4 +8,4 @@ namespace dgl {
namespace runtime {
DefaultGrainSizeT default_grain_size;
} // namespace runtime
} // namesoace dgl
} // namespace dgl
......@@ -3,13 +3,15 @@
* \file registry.cc
* \brief The global registry of packed function.
*/
#include <dgl/runtime/registry.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <unordered_map>
#include <mutex>
#include <memory>
#include <array>
#include <memory>
#include <mutex>
#include <unordered_map>
#include "runtime_base.h"
namespace dgl {
......@@ -18,9 +20,10 @@ namespace runtime {
struct Registry::Manager {
// map storing the functions.
// We delibrately used raw pointer
// This is because PackedFunc can contain callbacks into the host languge(python)
// and the resource can become invalid because of indeterminstic order of destruction.
// The resources will only be recycled during program exit.
// This is because PackedFunc can contain callbacks into the host
// languge(python) and the resource can become invalid because of
// indeterminstic order of destruction. The resources will only be recycled
// during program exit.
std::unordered_map<std::string, Registry*> fmap;
// vtable for extension type
std::array<ExtTypeVTable, kExtEnd> ext_vtable;
......@@ -44,7 +47,8 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*)
return *this;
}
Registry& Registry::Register(const std::string& name, bool override) { // NOLINT(*)
Registry& Registry::Register(
const std::string& name, bool override) { // NOLINT(*)
Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);
auto it = m->fmap.find(name);
......@@ -54,8 +58,7 @@ Registry& Registry::Register(const std::string& name, bool override) { // NOLIN
m->fmap[name] = r;
return *r;
} else {
CHECK(override)
<< "Global PackedFunc " << name << " is already registered";
CHECK(override) << "Global PackedFunc " << name << " is already registered";
return *it->second;
}
}
......@@ -82,7 +85,7 @@ std::vector<std::string> Registry::ListNames() {
std::lock_guard<std::mutex> lock(m->mutex);
std::vector<std::string> keys;
keys.reserve(m->fmap.size());
for (const auto &kv : m->fmap) {
for (const auto& kv : m->fmap) {
keys.push_back(kv.first);
}
return keys;
......@@ -92,8 +95,7 @@ ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
ExtTypeVTable* vt = &(m->ext_vtable[type_code]);
CHECK(vt->destroy != nullptr)
<< "Extension type not registered";
CHECK(vt->destroy != nullptr) << "Extension type not registered";
return vt;
}
......@@ -114,7 +116,7 @@ struct DGLFuncThreadLocalEntry {
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
std::vector<const char*> ret_vec_charp;
};
/*! \brief Thread local store that can be used to hold return values. */
......@@ -126,8 +128,7 @@ int DGLExtTypeFree(void* handle, int type_code) {
API_END();
}
int DGLFuncRegisterGlobal(
const char* name, DGLFunctionHandle f, int override) {
int DGLFuncRegisterGlobal(const char* name, DGLFunctionHandle f, int override) {
API_BEGIN();
dgl::runtime::Registry::Register(name, override != 0)
.set_body(*static_cast<dgl::runtime::PackedFunc*>(f));
......@@ -136,8 +137,7 @@ int DGLFuncRegisterGlobal(
int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) {
API_BEGIN();
const dgl::runtime::PackedFunc* fp =
dgl::runtime::Registry::Get(name);
const dgl::runtime::PackedFunc* fp = dgl::runtime::Registry::Get(name);
if (fp != nullptr) {
*out = new dgl::runtime::PackedFunc(*fp); // NOLINT(*)
} else {
......@@ -146,10 +146,9 @@ int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) {
API_END();
}
int DGLFuncListGlobalNames(int *out_size,
const char*** out_array) {
int DGLFuncListGlobalNames(int* out_size, const char*** out_array) {
API_BEGIN();
DGLFuncThreadLocalEntry *ret = DGLFuncThreadLocalStore::Get();
DGLFuncThreadLocalEntry* ret = DGLFuncThreadLocalStore::Get();
ret->ret_vec_str = dgl::runtime::Registry::ListNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
......
......@@ -14,9 +14,10 @@ namespace dgl {
namespace runtime {
/*
* The runtime allocates resources during the computation. Some of the resources cannot be
* destroyed after the process exits especially when the process doesn't exits normally.
* We need to keep track of the resources in the system and clean them up properly.
* The runtime allocates resources during the computation. Some of the resources
* cannot be destroyed after the process exits especially when the process
* doesn't exits normally. We need to keep track of the resources in the system
* and clean them up properly.
*/
class ResourceManager {
std::unordered_map<std::string, std::shared_ptr<Resource>> resources;
......@@ -25,12 +26,11 @@ class ResourceManager {
void Add(const std::string &key, std::shared_ptr<Resource> resource) {
auto it = resources.find(key);
CHECK(it == resources.end()) << key << " already exists";
resources.insert(std::pair<std::string, std::shared_ptr<Resource>>(key, resource));
resources.insert(
std::pair<std::string, std::shared_ptr<Resource>>(key, resource));
}
void Erase(const std::string &key) {
resources.erase(key);
}
void Erase(const std::string &key) { resources.erase(key); }
void Cleanup() {
for (auto it = resources.begin(); it != resources.end(); it++) {
......@@ -46,13 +46,9 @@ void AddResource(const std::string &key, std::shared_ptr<Resource> resource) {
manager.Add(key, resource);
}
void DeleteResource(const std::string &key) {
manager.Erase(key);
}
void DeleteResource(const std::string &key) { manager.Erase(key); }
void CleanupResources() {
manager.Cleanup();
}
void CleanupResources() { manager.Cleanup(); }
} // namespace runtime
} // namespace dgl
......@@ -6,24 +6,23 @@
#ifndef DGL_RUNTIME_RESOURCE_MANAGER_H_
#define DGL_RUNTIME_RESOURCE_MANAGER_H_
#include <unordered_map>
#include <string>
#include <memory>
#include <string>
#include <unordered_map>
namespace dgl {
namespace runtime {
/*
* A class that provides the interface to describe a resource that can be managed by
* a resource manager. Some of the resources cannot be free'd automatically when
* the process exits, especially when the process doesn't exit normally. One example
* is shared memory. We can keep track of this kind of resources and manage them
* properly.
* A class that provides the interface to describe a resource that can be
* managed by a resource manager. Some of the resources cannot be free'd
* automatically when the process exits, especially when the process doesn't
* exit normally. One example is shared memory. We can keep track of this kind
* of resources and manage them properly.
*/
class Resource {
public:
virtual ~Resource() {
}
virtual ~Resource() {}
virtual void Destroy() = 0;
};
......
......@@ -7,19 +7,32 @@
#define DGL_RUNTIME_RUNTIME_BASE_H_
#include <dgl/runtime/c_runtime_api.h>
#include <stdexcept>
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(std::runtime_error &_except_) { return DGLAPIHandleException(_except_); } return 0; // NOLINT(*)
#define API_END() \
} \
catch (std::runtime_error & _except_) { \
return DGLAPIHandleException(_except_); \
} \
return 0; // NOLINT(*)
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
* The finally clause contains procedure to cleanup states when an error
* happens.
*/
#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return DGLAPIHandleException(_except_); } return 0; // NOLINT(*)
#define API_END_HANDLE_ERROR(Finalize) \
} \
catch (std::runtime_error & _except_) { \
Finalize; \
return DGLAPIHandleException(_except_); \
} \
return 0; // NOLINT(*)
/*!
* \brief handle exception throwed out
......
......@@ -25,9 +25,7 @@ Semaphore::Semaphore() {
}
}
void Semaphore::Wait() {
WaitForSingleObject(sem_, INFINITE);
}
void Semaphore::Wait() { WaitForSingleObject(sem_, INFINITE); }
bool Semaphore::TimedWait(int) {
// Timed wait is not supported on WIN32.
......@@ -35,19 +33,13 @@ bool Semaphore::TimedWait(int) {
return true;
}
void Semaphore::Post() {
ReleaseSemaphore(sem_, 1, nullptr);
}
void Semaphore::Post() { ReleaseSemaphore(sem_, 1, nullptr); }
#else
Semaphore::Semaphore() {
sem_init(&sem_, 0, 0);
}
Semaphore::Semaphore() { sem_init(&sem_, 0, 0); }
void Semaphore::Wait() {
sem_wait(&sem_);
}
void Semaphore::Wait() { sem_wait(&sem_); }
bool Semaphore::TimedWait(int timeout) {
// sem_timedwait does not exist in Mac OS.
......@@ -92,9 +84,7 @@ bool Semaphore::TimedWait(int timeout) {
return true;
}
void Semaphore::Post() {
sem_post(&sem_);
}
void Semaphore::Post() { sem_post(&sem_); }
#endif
......
......@@ -12,7 +12,6 @@
#include <semaphore.h>
#endif
namespace dgl {
namespace runtime {
......@@ -31,7 +30,8 @@ class Semaphore {
void Wait();
/*!
* \brief timed wait, decrease semaphore by 1 or returns if times out
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
*/
bool TimedWait(int timeout);
/*!
......
......@@ -4,14 +4,14 @@
* \brief Shared memory management.
*/
#ifndef _WIN32
#include <sys/mman.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#endif
#include <dgl/runtime/shared_mem.h>
#include <dmlc/logging.h>
#include <stdio.h>
#include <string.h>
#include <dmlc/logging.h>
#include <dgl/runtime/shared_mem.h>
#include "resource_manager.h"
......@@ -22,21 +22,19 @@ namespace runtime {
* Shared memory is a resource that cannot be cleaned up if the process doesn't
* exit normally. We'll manage the resource with ResourceManager.
*/
class SharedMemoryResource: public Resource {
class SharedMemoryResource : public Resource {
std::string name;
public:
explicit SharedMemoryResource(const std::string &name) {
this->name = name;
}
explicit SharedMemoryResource(const std::string &name) { this->name = name; }
void Destroy() {
// LOG(INFO) << "remove " << name << " for shared memory";
#ifndef _WIN32
shm_unlink(name.c_str());
#else // _WIN32
// NOTHING; Windows automatically removes the shared memory object once all handles
// are unmapped.
#else // _WIN32
// NOTHING; Windows automatically removes the shared memory object once all
// handles are unmapped.
#endif
}
};
......@@ -55,24 +53,21 @@ SharedMemory::SharedMemory(const std::string &name) {
SharedMemory::~SharedMemory() {
#ifndef _WIN32
if (ptr_ && size_ != 0)
CHECK(munmap(ptr_, size_) != -1) << strerror(errno);
if (fd_ != -1)
close(fd_);
if (ptr_ && size_ != 0) CHECK(munmap(ptr_, size_) != -1) << strerror(errno);
if (fd_ != -1) close(fd_);
if (own_) {
// LOG(INFO) << "remove " << name << " for shared memory";
if (name != "") {
shm_unlink(name.c_str());
// The resource has been deleted. We don't need to keep track of it any more.
// The resource has been deleted. We don't need to keep track of it any
// more.
DeleteResource(name);
}
}
#else
if (ptr_)
CHECK(UnmapViewOfFile(ptr_)) << "Win32 Error: " << GetLastError();
if (handle_)
CloseHandle(handle_);
// Windows do not need a separate shm_unlink step.
if (ptr_) CHECK(UnmapViewOfFile(ptr_)) << "Win32 Error: " << GetLastError();
if (handle_) CloseHandle(handle_);
// Windows do not need a separate shm_unlink step.
#endif // _WIN32
}
......@@ -82,28 +77,26 @@ void *SharedMemory::CreateNew(size_t sz) {
// We need to create a shared-memory file.
// TODO(zhengda) we need to report error if the shared-memory file exists.
int flag = O_RDWR|O_CREAT;
int flag = O_RDWR | O_CREAT;
fd_ = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR);
CHECK_NE(fd_, -1) << "fail to open " << name << ": " << strerror(errno);
// Shared memory cannot be deleted if the process exits abnormally in Linux.
AddResource(name, std::shared_ptr<Resource>(new SharedMemoryResource(name)));
auto res = ftruncate(fd_, sz);
CHECK_NE(res, -1)
<< "Failed to truncate the file. " << strerror(errno);
ptr_ = mmap(NULL, sz, PROT_READ|PROT_WRITE, MAP_SHARED, fd_, 0);
CHECK_NE(res, -1) << "Failed to truncate the file. " << strerror(errno);
ptr_ = mmap(NULL, sz, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);
CHECK_NE(ptr_, MAP_FAILED)
<< "Failed to map shared memory. mmap failed with error " << strerror(errno);
<< "Failed to map shared memory. mmap failed with error "
<< strerror(errno);
this->size_ = sz;
return ptr_;
#else
handle_ = CreateFileMapping(
INVALID_HANDLE_VALUE,
nullptr,
PAGE_READWRITE,
static_cast<DWORD>(sz >> 32),
static_cast<DWORD>(sz & 0xFFFFFFFF),
INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE,
static_cast<DWORD>(sz >> 32), static_cast<DWORD>(sz & 0xFFFFFFFF),
name.c_str());
CHECK(handle_ != nullptr) << "fail to open " << name << ", Win32 error: " << GetLastError();
CHECK(handle_ != nullptr)
<< "fail to open " << name << ", Win32 error: " << GetLastError();
ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz);
if (ptr_ == nullptr) {
LOG(FATAL) << "Memory mapping failed, Win32 error: " << GetLastError();
......@@ -120,14 +113,16 @@ void *SharedMemory::Open(size_t sz) {
int flag = O_RDWR;
fd_ = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR);
CHECK_NE(fd_, -1) << "fail to open " << name << ": " << strerror(errno);
ptr_ = mmap(NULL, sz, PROT_READ|PROT_WRITE, MAP_SHARED, fd_, 0);
ptr_ = mmap(NULL, sz, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);
CHECK_NE(ptr_, MAP_FAILED)
<< "Failed to map shared memory. mmap failed with error " << strerror(errno);
<< "Failed to map shared memory. mmap failed with error "
<< strerror(errno);
this->size_ = sz;
return ptr_;
#else
handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name.c_str());
CHECK(handle_ != nullptr) << "fail to open " << name << ", Win32 Error: " << GetLastError();
CHECK(handle_ != nullptr)
<< "fail to open " << name << ", Win32 Error: " << GetLastError();
ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz);
if (ptr_ == nullptr) {
LOG(FATAL) << "Memory mapping failed, Win32 error: " << GetLastError();
......
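As context for the shm_open/mmap changes above, a minimal sketch of creating and attaching to a named segment with this class (not part of this commit; the segment name and size are placeholders, and in practice the Open side would normally run in a different process):

#include <dgl/runtime/shared_mem.h>

void ShareBufferExample() {
  dgl::runtime::SharedMemory writer("dgl_example_buf");
  void* wbuf = writer.CreateNew(1 << 20);  // shm_open + ftruncate + mmap, ~1 MiB

  dgl::runtime::SharedMemory reader("dgl_example_buf");
  void* rbuf = reader.Open(1 << 20);       // maps the existing segment
  (void)wbuf; (void)rbuf;
}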
......@@ -3,9 +3,11 @@
* \file system_lib_module.cc
* \brief SystemLib module.
*/
#include <dgl/runtime/registry.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/registry.h>
#include <mutex>
#include "module_util.h"
namespace dgl {
......@@ -15,9 +17,7 @@ class SystemLibModuleNode : public ModuleNode {
public:
SystemLibModuleNode() = default;
const char* type_key() const final {
return "system_lib";
}
const char* type_key() const final { return "system_lib"; }
PackedFunc GetFunction(
const std::string& name,
......@@ -57,8 +57,8 @@ class SystemLibModuleNode : public ModuleNode {
auto it = tbl_.find(name);
if (it != tbl_.end() && ptr != it->second) {
LOG(WARNING) << "SystemLib symbol " << name
<< " get overriden to a different address "
<< ptr << "->" << it->second;
<< " get overriden to a different address " << ptr << "->"
<< it->second;
}
tbl_[name] = ptr;
}
......@@ -80,9 +80,9 @@ class SystemLibModuleNode : public ModuleNode {
};
DGL_REGISTER_GLOBAL("module._GetSystemLib")
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = runtime::Module(SystemLibModuleNode::Global());
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = runtime::Module(SystemLibModuleNode::Global());
});
} // namespace runtime
} // namespace dgl
......
......@@ -4,12 +4,12 @@
* \brief Adapter library caller
*/
#include <dgl/runtime/tensordispatch.h>
#include <dgl/runtime/registry.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h>
#if defined(WIN32) || defined(_WIN32)
#include <windows.h>
#else // !WIN32
#else // !WIN32
#include <dlfcn.h>
#endif // WIN32
#include <cstring>
......@@ -23,25 +23,27 @@ bool TensorDispatcher::Load(const char *path) {
CHECK(!available_) << "The tensor adapter can only load once.";
if (path == nullptr || strlen(path) == 0)
// does not have dispatcher library; all operators fall back to DGL's implementation
// does not have dispatcher library; all operators fall back to DGL's
// implementation
return false;
#if defined(WIN32) || defined(_WIN32)
handle_ = LoadLibrary(path);
if (!handle_)
return false;
if (!handle_) return false;
for (int i = 0; i < num_entries_; ++i) {
entrypoints_[i] = reinterpret_cast<void*>(GetProcAddress(handle_, names_[i]));
entrypoints_[i] =
reinterpret_cast<void *>(GetProcAddress(handle_, names_[i]));
CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i];
}
#else // !WIN32
handle_ = dlopen(path, RTLD_LAZY);
if (!handle_) {
DLOG(WARNING) << "Could not open file: " << dlerror()
<< ". This does not affect DGL's but might impact its performance.";
DLOG(WARNING)
<< "Could not open file: " << dlerror()
<< ". This does not affect DGL's but might impact its performance.";
return false;
}
......
......@@ -3,23 +3,24 @@
* \file thread_pool.cc
* \brief Threadpool for multi-threading runtime.
*/
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/threading_backend.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <thread>
#include <condition_variable>
#include <mutex>
#include <atomic>
#include <dmlc/thread_local.h>
#include <algorithm>
#include <vector>
#include <string>
#include <atomic>
#include <condition_variable>
#include <cstring>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <vector>
const constexpr int kL1CacheBytes = 64;
......@@ -35,10 +36,8 @@ constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);
class ParallelLauncher {
public:
// Reset the the task request.
void Init(FDGLParallelLambda flambda,
void* cdata,
int num_task,
bool need_sync) {
void Init(
FDGLParallelLambda flambda, void* cdata, int num_task, bool need_sync) {
num_pending_.store(num_task);
this->cdata = cdata;
this->flambda = flambda;
......@@ -54,17 +53,14 @@ class ParallelLauncher {
}
if (need_sync) {
for (int i = 0; i < num_task; ++i) {
sync_counter_[i * kSyncStride].store(
0, std::memory_order_relaxed);
sync_counter_[i * kSyncStride].store(0, std::memory_order_relaxed);
}
this->env.sync_handle = sync_counter_;
} else {
this->env.sync_handle = nullptr;
}
}
~ParallelLauncher() {
delete[] sync_counter_;
}
~ParallelLauncher() { delete[] sync_counter_; }
// Wait n jobs to finish
int WaitForJobs() {
while (num_pending_.load() != 0) {
......@@ -90,9 +86,7 @@ class ParallelLauncher {
has_error_.store(true);
}
// Signal that one job has finished.
void SignalJobFinish() {
num_pending_.fetch_sub(1);
}
void SignalJobFinish() { num_pending_.fetch_sub(1); }
// Get thread local version of the store.
static ParallelLauncher* ThreadLocal() {
return dmlc::ThreadLocalStore<ParallelLauncher>::Get();
......@@ -127,15 +121,9 @@ class SpscTaskQueue {
int32_t task_id;
};
SpscTaskQueue() :
buffer_(new Task[kRingSize]),
head_(0),
tail_(0) {
}
SpscTaskQueue() : buffer_(new Task[kRingSize]), head_(0), tail_(0) {}
~SpscTaskQueue() {
delete[] buffer_;
}
~SpscTaskQueue() { delete[] buffer_; }
/*!
* \brief Push a task into the queue and notify the comsumer if it is on wait.
......@@ -159,16 +147,16 @@ class SpscTaskQueue {
*/
bool Pop(Task* output, uint32_t spin_count = 300000) {
// Busy wait a bit when the queue is empty.
// If a new task comes to the queue quickly, this wait avoid the worker from sleeping.
// The default spin count is set by following the typical omp convention
// If a new task comes to the queue quickly, this wait avoid the worker from
// sleeping. The default spin count is set by following the typical omp
// convention
for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
dgl::runtime::threading::YieldThread();
}
if (pending_.fetch_sub(1) == 0) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] {
return pending_.load() >= 0 || exit_now_.load();
});
cv_.wait(
lock, [this] { return pending_.load() >= 0 || exit_now_.load(); });
}
if (exit_now_.load(std::memory_order_relaxed)) {
return false;
......@@ -209,7 +197,8 @@ class SpscTaskQueue {
return false;
}
// the cache line paddings are used for avoid false sharing between atomic variables
// the cache line paddings are used for avoid false sharing between atomic
// variables
typedef char cache_line_pad_t[kL1CacheBytes];
cache_line_pad_t pad0_;
// size of the queue, the queue can host size_ - 1 items at most
......@@ -243,16 +232,17 @@ class SpscTaskQueue {
// The thread pool
class ThreadPool {
public:
ThreadPool(): num_workers_(dgl::runtime::threading::MaxConcurrency()) {
ThreadPool() : num_workers_(dgl::runtime::threading::MaxConcurrency()) {
for (int i = 0; i < num_workers_; ++i) {
// The SpscTaskQueue only hosts ONE item at a time
queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
}
threads_ = std::unique_ptr<dgl::runtime::threading::ThreadGroup>(
new dgl::runtime::threading::ThreadGroup(
num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
exclude_worker0_ /* include_main_thread */));
num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);
num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
exclude_worker0_ /* include_main_thread */));
num_workers_used_ =
threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);
}
~ThreadPool() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
......@@ -260,13 +250,11 @@ class ThreadPool {
}
threads_.reset();
}
int Launch(FDGLParallelLambda flambda,
void* cdata,
int num_task,
int need_sync) {
int Launch(
FDGLParallelLambda flambda, void* cdata, int num_task, int need_sync) {
ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
CHECK(!launcher->is_worker)
<< "Cannot launch parallel job inside worker, consider fuse then parallel";
CHECK(!launcher->is_worker) << "Cannot launch parallel job inside worker, "
"consider fuse then parallel";
if (num_task == 0) {
num_task = num_workers_used_;
}
......@@ -300,11 +288,11 @@ class ThreadPool {
return dmlc::ThreadLocalStore<ThreadPool>::Get();
}
void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) {
void UpdateWorkerConfiguration(
threading::ThreadGroup::AffinityMode mode, int nthreads) {
// this will also reset the affinity of the ThreadGroup
// may use less than the MaxConcurrency number of workers
num_workers_used_ = threads_->Configure(mode, nthreads,
exclude_worker0_);
num_workers_used_ = threads_->Configure(mode, nthreads, exclude_worker0_);
// if MaxConcurrency restricted the number of workers (e.g., due to
// hyperthreading), respect the restriction
num_workers_used_ = std::min(num_workers_, num_workers_used_);
......@@ -341,23 +329,19 @@ class ThreadPool {
};
DGL_REGISTER_GLOBAL("runtime.config_threadpool")
.set_body([](DGLArgs args, DGLRetValue* rv) {
threading::ThreadGroup::AffinityMode mode =\
static_cast<threading::ThreadGroup::AffinityMode>(\
static_cast<int>(args[0]));
int nthreads = args[1];
ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
threading::ThreadGroup::AffinityMode mode =
static_cast<threading::ThreadGroup::AffinityMode>(
static_cast<int>(args[0]));
int nthreads = args[1];
ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
});
} // namespace runtime
} // namespace dgl
int DGLBackendParallelLaunch(
FDGLParallelLambda flambda,
void* cdata,
int num_task) {
FDGLParallelLambda flambda, void* cdata, int num_task) {
int res = dgl::runtime::ThreadPool::ThreadLocal()->Launch(
flambda, cdata, num_task, 1);
return res;
......@@ -372,8 +356,8 @@ int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv) {
1, std::memory_order_release);
for (int i = 0; i < num_task; ++i) {
if (i != task_id) {
while (sync_counter[i * kSyncStride].load(
std::memory_order_relaxed) <= old_counter) {
while (sync_counter[i * kSyncStride].load(std::memory_order_relaxed) <=
old_counter) {
dgl::runtime::threading::YieldThread();
}
}
......
......@@ -7,6 +7,7 @@
#define DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
#include <dgl/runtime/packed_func.h>
#include <string>
#include <vector>
......@@ -40,9 +41,12 @@ enum class StorageRank {
*/
inline StorageRank DefaultStorageRank(int thread_scope_rank) {
switch (thread_scope_rank) {
case -1: return StorageRank::kGlobal;
case 0: return StorageRank::kShared;
case 1: return StorageRank::kLocal;
case -1:
return StorageRank::kGlobal;
case 0:
return StorageRank::kShared;
case 1:
return StorageRank::kLocal;
default: {
LOG(FATAL) << "unknown rank";
return StorageRank::kGlobal;
......@@ -66,11 +70,17 @@ struct StorageScope {
inline std::string to_string() const {
std::string ret;
switch (rank) {
case StorageRank::kGlobal: return "global" + tag;
case StorageRank::kShared: return "shared" + tag;
case StorageRank::kWarp: return "warp" + tag;
case StorageRank::kLocal: return "local" + tag;
default: LOG(FATAL) << "unknown storage scope"; return "";
case StorageRank::kGlobal:
return "global" + tag;
case StorageRank::kShared:
return "shared" + tag;
case StorageRank::kWarp:
return "warp" + tag;
case StorageRank::kLocal:
return "local" + tag;
default:
LOG(FATAL) << "unknown storage scope";
return "";
}
}
/*!
......@@ -80,7 +90,7 @@ struct StorageScope {
*/
static StorageScope make(const std::string& s) {
StorageScope r;
if (s.compare(0, 6, "global") == 0) {
if (s.compare(0, 6, "global") == 0) {
r.rank = StorageRank::kGlobal;
r.tag = s.substr(6, std::string::npos);
} else if (s.compare(0, 6, "shared") == 0) {
......@@ -129,7 +139,6 @@ struct ThreadScope {
}
};
/*! \brief workload speccification */
struct ThreadWorkLoad {
// array, first three are thread configuration.
......@@ -138,22 +147,17 @@ struct ThreadWorkLoad {
* \param i The block dimension.
* \return i-th block dim
*/
inline size_t block_dim(size_t i) const {
return work_size[i + 3];
}
inline size_t block_dim(size_t i) const { return work_size[i + 3]; }
/*!
* \param i The grid dimension.
* \return i-th grid dim
*/
inline size_t grid_dim(size_t i) const {
return work_size[i];
}
inline size_t grid_dim(size_t i) const { return work_size[i]; }
};
/*! \brief Thread axis configuration */
class ThreadAxisConfig {
public:
void Init(size_t base,
const std::vector<std::string>& thread_axis_tags) {
void Init(size_t base, const std::vector<std::string>& thread_axis_tags) {
base_ = base;
std::vector<bool> filled(6, false);
for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
......@@ -180,9 +184,7 @@ class ThreadAxisConfig {
return w;
}
// return the work dim
size_t work_dim() const {
return work_dim_;
}
size_t work_dim() const { return work_dim_; }
private:
/*! \brief base axis */
......