Unverified Commit 401e1278 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4811)



* [Misc] clang-format auto fix.

* fix

* manual
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 6c53f351
......@@ -11,7 +11,7 @@
#include <ws2tcpip.h>
#pragma comment(lib, "Ws2_32.lib")
#else // !_WIN32
#else // !_WIN32
#include <sys/socket.h>
#endif // _WIN32
#include <string>
......@@ -20,7 +20,7 @@ namespace dgl {
namespace network {
/*!
* \brief TCPSocket is a simple wrapper around a socket.
* \brief TCPSocket is a simple wrapper around a socket.
* It supports only TCP connections.
*/
class TCPSocket {
......@@ -32,7 +32,7 @@ class TCPSocket {
/*!
* \brief TCPSocket destructor
*/
*/
~TCPSocket();
/*!
......@@ -41,7 +41,7 @@ class TCPSocket {
* \param port end port
* \return true for success and false for failure
*/
bool Connect(const char * ip, int port);
bool Connect(const char* ip, int port);
/*!
* \brief Bind on the given IP and PORT
......@@ -49,7 +49,7 @@ class TCPSocket {
* \param port end port
* \return true for success and false for failure
*/
bool Bind(const char * ip, int port);
bool Bind(const char* ip, int port);
/*!
* \brief listen for remote connection
......@@ -65,9 +65,7 @@ class TCPSocket {
* \param port_client new PORT will be stored to port_client
* \return true for success and false for failure
*/
bool Accept(TCPSocket * socket,
std::string * ip_client,
int * port_client);
bool Accept(TCPSocket* socket, std::string* ip_client, int* port_client);
/*!
* \brief SetNonBlocking() is needed referring to this example of epoll:
......@@ -103,27 +101,27 @@ class TCPSocket {
* \param data data for sending
* \param len_data length of data
* \return return number of bytes sent if OK, -1 on error
*/
int64_t Send(const char * data, int64_t len_data);
*/
int64_t Send(const char* data, int64_t len_data);
/*!
* \brief Receive data.
* \param buffer buffer for receiving
* \param size_buffer size of buffer
* \return return number of bytes received if OK, -1 on error
*/
int64_t Receive(char * buffer, int64_t size_buffer);
*/
int64_t Receive(char* buffer, int64_t size_buffer);
/*!
* \brief Get socket's file descriptor
* \return socket's file descriptor
*/
*/
int Socket() const;
private:
/*!
* \brief socket's file descriptor
*/
*/
int socket_;
};
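The TCPSocket class above wraps the usual blocking-socket lifecycle (Connect/Bind/Listen/Accept plus Send/Receive). The snippet below is a minimal client-side sketch, not part of this commit: it assumes the header above is included, that TCPSocket is default-constructible, and the helper name SendGreeting is purely illustrative.
// Hypothetical usage sketch: send a small payload over dgl::network::TCPSocket,
// relying only on the methods declared above.
#include <cstdint>
#include <string>
bool SendGreeting(const char* ip, int port) {
  dgl::network::TCPSocket sock;
  if (!sock.Connect(ip, port)) {  // Connect() returns false on failure
    return false;
  }
  const std::string payload = "hello";
  // Send() returns the number of bytes sent, or -1 on error.
  int64_t sent = sock.Send(payload.data(), static_cast<int64_t>(payload.size()));
  return sent == static_cast<int64_t>(payload.size());
}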
......
......@@ -6,24 +6,25 @@
#ifndef DGL_RPC_RPC_H_
#define DGL_RPC_RPC_H_
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/thread_local.h>
#include <cstdint>
#include <memory>
#include <deque>
#include <vector>
#include <string>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "./network/common.h"
#include "./rpc_msg.h"
#include "./server_state.h"
#include "net_type.h"
#include "network/socket_communicator.h"
#include "tensorpipe/tp_communicator.h"
#include "./network/common.h"
#include "./server_state.h"
namespace dgl {
namespace rpc {
......@@ -138,7 +139,7 @@ struct RPCContext {
}
int32_t RegisterClient(int32_t client_id, int32_t group_id) {
auto &&m = clients_[group_id];
auto&& m = clients_[group_id];
if (m.find(client_id) != m.end()) {
return -1;
}
......@@ -150,7 +151,7 @@ struct RPCContext {
if (clients_.find(group_id) == clients_.end()) {
return -1;
}
const auto &m = clients_.at(group_id);
const auto& m = clients_.at(group_id);
if (m.find(client_id) == m.end()) {
return -1;
}
......
......@@ -6,8 +6,8 @@
#ifndef DGL_RPC_RPC_MSG_H_
#define DGL_RPC_RPC_MSG_H_
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <string>
......
......@@ -7,11 +7,12 @@
#ifndef DGL_RPC_SERVER_STATE_H_
#define DGL_RPC_SERVER_STATE_H_
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/base_heterograph.h>
#include <unordered_map>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <string>
#include <unordered_map>
namespace dgl {
namespace rpc {
......
......@@ -9,10 +9,11 @@
#define DGL_RPC_TENSORPIPE_QUEUE_H_
#include <dmlc/logging.h>
#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <chrono>
#include <utility>
namespace dgl {
......@@ -39,8 +40,9 @@ class Queue {
DLOG(WARNING) << "Will wait infinitely until message is popped...";
cv_.wait(lock, [this] { return items_.size() > 0; });
} else {
if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout),
[this] { return items_.size() > 0; })) {
if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout), [this] {
return items_.size() > 0;
})) {
DLOG(WARNING) << "Times out for popping message after " << timeout
<< " milliseconds.";
return false;
......
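The Pop path above uses the standard timed-wait idiom: wait_for with a predicate returns false only when the timeout expires while the queue is still empty. A self-contained sketch of the same pattern using only standard C++ (all names are hypothetical):
#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
template <typename T>
bool TimedPop(std::deque<T>* items, std::mutex* mu,
              std::condition_variable* cv, T* out, int timeout_ms) {
  std::unique_lock<std::mutex> lock(*mu);
  // wait_for returns false iff the deadline passed with the predicate still false.
  if (!cv->wait_for(lock, std::chrono::milliseconds(timeout_ms),
                    [&] { return !items->empty(); })) {
    return false;  // timed out, nothing to pop
  }
  *out = items->front();
  items->pop_front();
  return true;
}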
......@@ -48,8 +48,8 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
StreamWithBuffer zc_write_strm(zerocopy_blob_ptr, true);
zc_write_strm.Write(msg);
int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
zerocopy_blob_ptr->append(reinterpret_cast<char *>(&nonempty_ndarray_count),
sizeof(int32_t));
zerocopy_blob_ptr->append(
reinterpret_cast<char *>(&nonempty_ndarray_count), sizeof(int32_t));
tp_msg.tensors.resize(nonempty_ndarray_count);
// Hold the NDArrays to ensure they remain valid until the write operation completes
auto ndarray_holder = std::make_shared<std::vector<NDArray>>();
......@@ -68,14 +68,14 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
}
// Let's write blockingly in case of congestion in underlying transports.
auto done = std::make_shared<std::promise<void>>();
pipe->write(tp_msg,
[ndarray_holder, recv_id, done](const tensorpipe::Error &error) {
if (error) {
LOG(FATAL) << "Failed to send message to " << recv_id
<< ". Details: " << error.what();
}
done->set_value();
});
pipe->write(
tp_msg, [ndarray_holder, recv_id, done](const tensorpipe::Error &error) {
if (error) {
LOG(FATAL) << "Failed to send message to " << recv_id
<< ". Details: " << error.what();
}
done->set_value();
});
done->get_future().wait();
}
......@@ -120,7 +120,8 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
if (error.isOfType<ListenerClosedError>()) {
// Expected.
} else {
LOG(WARNING) << "Unexpected error when accepting incoming pipe: " << error.what();
LOG(WARNING) << "Unexpected error when accepting incoming pipe: "
<< error.what();
}
return;
}
......@@ -133,7 +134,8 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
// read the handshake message: "dglconnect"
pipe->readDescriptor([pipe, this](const Error &error, Descriptor descriptor) {
if (error) {
LOG(ERROR) << "Unexpected error when reading from accepted pipe: " << error.what();
LOG(ERROR) << "Unexpected error when reading from accepted pipe: "
<< error.what();
return;
}
Allocation allocation;
......@@ -145,10 +147,10 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
});
}
void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
std::shared_ptr<RPCMessageQueue> queue) {
pipe->readDescriptor([pipe, queue = std::move(queue)](const Error &error,
Descriptor descriptor) {
void TPReceiver::ReceiveFromPipe(
std::shared_ptr<Pipe> pipe, std::shared_ptr<RPCMessageQueue> queue) {
pipe->readDescriptor([pipe, queue = std::move(queue)](
const Error &error, Descriptor descriptor) {
if (error) {
// Error may happen when the pipe is closed
return;
......@@ -165,31 +167,33 @@ void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
allocation.tensors[i].buffer = cpu_buffer;
}
}
pipe->read(allocation, [allocation, descriptor = std::move(descriptor),
queue = std::move(queue),
pipe](const Error &error) {
if (error) {
// Because we always have a read event posted to the epoll,
// Therefore when pipe is closed, error will be raised.
// But this error is expected.
// Other error is not expected. But we cannot identify the error with
// each Other for now. Thus here we skip handling for all errors
return;
}
char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
std::vector<void *> buffer_list(descriptor.tensors.size());
for (size_t i = 0; i < descriptor.tensors.size(); i++) {
buffer_list[i] = allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
}
StreamWithBuffer zc_read_strm(
meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t),
buffer_list);
RPCMessage msg;
zc_read_strm.Read(&msg);
queue->push(msg);
TPReceiver::ReceiveFromPipe(pipe, queue);
});
pipe->read(
allocation, [allocation, descriptor = std::move(descriptor),
queue = std::move(queue), pipe](const Error &error) {
if (error) {
// Because a read event is always posted to the epoll, an error is
// raised here when the pipe is closed; that error is expected.
// Other errors are unexpected, but we cannot distinguish them from
// the expected one for now, so handling is skipped for all errors
// here.
return;
}
char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
std::vector<void *> buffer_list(descriptor.tensors.size());
for (size_t i = 0; i < descriptor.tensors.size(); i++) {
buffer_list[i] =
allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
}
StreamWithBuffer zc_read_strm(
meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t),
buffer_list);
RPCMessage msg;
zc_read_strm.Read(&msg);
queue->push(msg);
TPReceiver::ReceiveFromPipe(pipe, queue);
});
});
}
......
......@@ -9,15 +9,16 @@
#include <dmlc/logging.h>
#include <tensorpipe/tensorpipe.h>
#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include <atomic>
#include "./queue.h"
#include "../net_type.h"
#include "./queue.h"
namespace dgl {
namespace rpc {
......@@ -47,11 +48,12 @@ class TPSender : public RPCSender {
/*!
* \brief Connect to a receiver.
*
* When there are multiple receivers to be connected, application will call `ConnectReceiver`
* for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
* successfully established or some of them fail.
*
*
* When there are multiple receivers to be connected, application will call
* `ConnectReceiver` for each and then call `ConnectReceiverFinalize` to make
* sure that either all the connections are successfully established or some
* of them fail.
*
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param recv_id receiver's ID
* \return True for success and False for fail
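The comment above describes the connection protocol: call ConnectReceiver once per receiver, then call ConnectReceiverFinalize to learn whether every connection was established. A rough sketch of that call pattern follows; the parameter types are inferred from the doc comment, ConnectReceiverFinalize's signature is not shown in this diff (the no-argument call is an assumption), and ConnectAllReceivers is a hypothetical helper name.
#include <string>
#include <utility>
#include <vector>
// Hypothetical helper illustrating the protocol described above.
bool ConnectAllReceivers(
    dgl::rpc::TPSender* sender,
    const std::vector<std::pair<std::string, int>>& receivers) {
  for (const auto& r : receivers) {
    // r.first is an address such as "tcp://127.0.0.1:50091", r.second its ID.
    if (!sender->ConnectReceiver(r.first, r.second)) return false;
  }
  // Signature assumed; the header above only mentions this call by name.
  return sender->ConnectReceiverFinalize();
}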
......@@ -75,7 +77,7 @@ class TPSender : public RPCSender {
/*!
* \brief Communicator type: 'tp'
*/
const std::string &NetType() const override {
const std::string& NetType() const override {
static const std::string net_type = "tensorpipe";
return net_type;
}
......@@ -90,7 +92,7 @@ class TPSender : public RPCSender {
* \brief pipe for each connection of receiver
*/
std::unordered_map<int /* receiver ID */, std::shared_ptr<tensorpipe::Pipe>>
pipes_;
pipes_;
/*!
* \brief receivers' listening address
......@@ -129,13 +131,14 @@ class TPReceiver : public RPCReceiver {
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
bool Wait(const std::string &addr, int num_sender,
bool blocking = true) override;
bool Wait(
const std::string& addr, int num_sender, bool blocking = true) override;
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/
RPCStatus Recv(RPCMessage* msg, int timeout) override;
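Per the doc comment above, Recv blocks for at most timeout milliseconds (zero means wait indefinitely) and reports kRPCSuccess or kRPCTimeOut. A minimal polling loop sketched under those assumptions; DrainMessages is a hypothetical name.
// Hypothetical receive loop based on the semantics documented above.
void DrainMessages(dgl::rpc::TPReceiver* receiver) {
  dgl::rpc::RPCMessage msg;
  // Wait up to 100 ms per attempt; kRPCTimeOut simply means nothing arrived yet.
  while (receiver->Recv(&msg, /*timeout=*/100) == dgl::rpc::kRPCSuccess) {
    // ... handle msg ...
  }
}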
......@@ -150,7 +153,7 @@ class TPReceiver : public RPCReceiver {
/*!
* \brief Communicator type: 'tp' (tensorpipe)
*/
const std::string &NetType() const override {
const std::string& NetType() const override {
static const std::string net_type = "tensorpipe";
return net_type;
}
......@@ -158,8 +161,9 @@ class TPReceiver : public RPCReceiver {
/*!
* \brief Issue a receive request on pipe, and push the result into queue
*/
static void ReceiveFromPipe(std::shared_ptr<tensorpipe::Pipe> pipe,
std::shared_ptr<RPCMessageQueue> queue);
static void ReceiveFromPipe(
std::shared_ptr<tensorpipe::Pipe> pipe,
std::shared_ptr<RPCMessageQueue> queue);
private:
/*!
......@@ -186,9 +190,9 @@ class TPReceiver : public RPCReceiver {
/*!
* \brief pipe for each client connections
*/
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<tensorpipe::Pipe>>
pipes_;
std::unordered_map<
int /* Sender (virtual) ID */, std::shared_ptr<tensorpipe::Pipe>>
pipes_;
/*!
* \brief RPCMessage queue
......
......@@ -3,16 +3,18 @@
* Implementation of C API (reference: tvm/src/api/c_api.cc)
* \file c_api.cc
*/
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/c_object_api.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_object_api.h>
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/packed_func_ext.h>
#include <vector>
#include <string>
#include <exception>
#include <string>
#include <vector>
#include "runtime_base.h"
/*! \brief entry to to easily hold returning information */
......@@ -20,7 +22,7 @@ struct DGLAPIThreadLocalEntry {
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
std::vector<const char*> ret_vec_charp;
/*! \brief result holder for returning string */
std::string ret_str;
};
......@@ -44,7 +46,8 @@ struct APIAttrGetter : public AttrVisitor {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, uint64_t* value) final {
CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
CHECK_LE(
value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
<< "cannot return too big constant";
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
......@@ -71,30 +74,16 @@ struct APIAttrGetter : public AttrVisitor {
struct APIAttrDir : public AttrVisitor {
std::vector<std::string>* names;
void Visit(const char* key, double* value) final {
names->push_back(key);
}
void Visit(const char* key, int64_t* value) final {
names->push_back(key);
}
void Visit(const char* key, uint64_t* value) final {
names->push_back(key);
}
void Visit(const char* key, bool* value) final {
names->push_back(key);
}
void Visit(const char* key, int* value) final {
names->push_back(key);
}
void Visit(const char* key, double* value) final { names->push_back(key); }
void Visit(const char* key, int64_t* value) final { names->push_back(key); }
void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
void Visit(const char* key, bool* value) final { names->push_back(key); }
void Visit(const char* key, int* value) final { names->push_back(key); }
void Visit(const char* key, std::string* value) final {
names->push_back(key);
}
void Visit(const char* key, ObjectRef* value) final {
names->push_back(key);
}
void Visit(const char* key, NDArray* value) final {
names->push_back(key);
}
void Visit(const char* key, ObjectRef* value) final { names->push_back(key); }
void Visit(const char* key, NDArray* value) final { names->push_back(key); }
};
int DGLObjectFree(ObjectHandle handle) {
......@@ -103,26 +92,22 @@ int DGLObjectFree(ObjectHandle handle) {
API_END();
}
int DGLObjectTypeKey2Index(const char* type_key,
int* out_index) {
int DGLObjectTypeKey2Index(const char* type_key, int* out_index) {
API_BEGIN();
*out_index = static_cast<int>(Object::TypeKey2Index(type_key));
API_END();
}
int DGLObjectGetTypeIndex(ObjectHandle handle,
int* out_index) {
int DGLObjectGetTypeIndex(ObjectHandle handle, int* out_index) {
API_BEGIN();
*out_index = static_cast<int>(
(*static_cast<DGLAPIObject*>(handle))->type_index());
*out_index =
static_cast<int>((*static_cast<DGLAPIObject*>(handle))->type_index());
API_END();
}
int DGLObjectGetAttr(ObjectHandle handle,
const char* key,
DGLValue* ret_val,
int* ret_type_code,
int* ret_success) {
int DGLObjectGetAttr(
ObjectHandle handle, const char* key, DGLValue* ret_val, int* ret_type_code,
int* ret_success) {
API_BEGIN();
DGLRetValue rv;
APIAttrGetter getter;
......@@ -136,9 +121,8 @@ int DGLObjectGetAttr(ObjectHandle handle,
} else {
(*tobject)->VisitAttrs(&getter);
*ret_success = getter.found_object_ref || rv.type_code() != kNull;
if (rv.type_code() == kStr ||
rv.type_code() == kDGLDataType) {
DGLAPIThreadLocalEntry *e = DGLAPIThreadLocalStore::Get();
if (rv.type_code() == kStr || rv.type_code() == kDGLDataType) {
DGLAPIThreadLocalEntry* e = DGLAPIThreadLocalStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
......@@ -149,10 +133,9 @@ int DGLObjectGetAttr(ObjectHandle handle,
API_END();
}
int DGLObjectListAttrNames(ObjectHandle handle,
int *out_size,
const char*** out_array) {
DGLAPIThreadLocalEntry *ret = DGLAPIThreadLocalStore::Get();
int DGLObjectListAttrNames(
ObjectHandle handle, int* out_size, const char*** out_array) {
DGLAPIThreadLocalEntry* ret = DGLAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str.clear();
DGLAPIObject* tobject = static_cast<DGLAPIObject*>(handle);
......
......@@ -3,18 +3,20 @@
* \file c_runtime_api.cc
* \brief Runtime API implementation
*/
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>
#include <array>
#include <dmlc/thread_local.h>
#include <algorithm>
#include <string>
#include <array>
#include <cstdlib>
#include <string>
#include "runtime_base.h"
namespace dgl {
......@@ -26,10 +28,14 @@ namespace runtime {
*/
inline std::string DeviceName(int type) {
switch (type) {
case kDGLCPU: return "cpu";
case kDGLCUDA: return "cuda";
case kDGLCPU:
return "cpu";
case kDGLCUDA:
return "cuda";
// add more device here once supported
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
default:
LOG(FATAL) << "unknown type =" << type;
return "Unknown";
}
}
......@@ -37,9 +43,7 @@ class DeviceAPIManager {
public:
static const int kMaxDeviceAPI = 32;
// Get API
static DeviceAPI* Get(const DGLContext& ctx) {
return Get(ctx.device_type);
}
static DeviceAPI* Get(const DGLContext& ctx) { return Get(ctx.device_type); }
static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
return Global()->GetAPI(dev_type, allow_missing);
}
......@@ -49,9 +53,7 @@ class DeviceAPIManager {
DeviceAPI* rpc_api_{nullptr};
std::mutex mutex_;
// constructor
DeviceAPIManager() {
std::fill(api_.begin(), api_.end(), nullptr);
}
DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
// Global static variable.
static DeviceAPIManager* Global() {
static DeviceAPIManager inst;
......@@ -78,7 +80,8 @@ class DeviceAPIManager {
auto* f = Registry::Get(factory);
if (f == nullptr) {
CHECK(allow_missing)
<< "Device API " << name << " is not enabled. Please install the cuda version of dgl.";
<< "Device API " << name
<< " is not enabled. Please install the cuda version of dgl.";
return nullptr;
}
void* ptr = (*f)();
......@@ -95,9 +98,8 @@ DeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {
return DeviceAPIManager::Get(static_cast<int>(dev_type), allow_missing);
}
void* DeviceAPI::AllocWorkspace(DGLContext ctx,
size_t size,
DGLDataType type_hint) {
void* DeviceAPI::AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) {
return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}
......@@ -114,9 +116,8 @@ void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
LOG(FATAL) << "Device does not support stream api.";
}
void DeviceAPI::SyncStreamFromTo(DGLContext ctx,
DGLStreamHandle event_src,
DGLStreamHandle event_dst) {
void DeviceAPI::SyncStreamFromTo(
DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api.";
}
......@@ -140,7 +141,7 @@ struct DGLRuntimeEntry {
typedef dmlc::ThreadLocalStore<DGLRuntimeEntry> DGLAPIRuntimeStore;
const char *DGLGetLastError() {
const char* DGLGetLastError() {
return DGLAPIRuntimeStore::Get()->last_error.c_str();
}
......@@ -152,30 +153,26 @@ void DGLAPISetLastError(const char* msg) {
#endif
}
int DGLModLoadFromFile(const char* file_name,
const char* format,
DGLModuleHandle* out) {
int DGLModLoadFromFile(
const char* file_name, const char* format, DGLModuleHandle* out) {
API_BEGIN();
Module m = Module::LoadFromFile(file_name, format);
*out = new Module(m);
API_END();
}
int DGLModImport(DGLModuleHandle mod,
DGLModuleHandle dep) {
int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep) {
API_BEGIN();
static_cast<Module*>(mod)->Import(
*static_cast<Module*>(dep));
static_cast<Module*>(mod)->Import(*static_cast<Module*>(dep));
API_END();
}
int DGLModGetFunction(DGLModuleHandle mod,
const char* func_name,
int query_imports,
DGLFunctionHandle *func) {
int DGLModGetFunction(
DGLModuleHandle mod, const char* func_name, int query_imports,
DGLFunctionHandle* func) {
API_BEGIN();
PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
func_name, query_imports != 0);
PackedFunc pf =
static_cast<Module*>(mod)->GetFunction(func_name, query_imports != 0);
if (pf != nullptr) {
*func = new PackedFunc(pf);
} else {
......@@ -190,20 +187,18 @@ int DGLModFree(DGLModuleHandle mod) {
API_END();
}
int DGLBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
DGLFunctionHandle *func) {
int DGLBackendGetFuncFromEnv(
void* mod_node, const char* func_name, DGLFunctionHandle* func) {
API_BEGIN();
*func = (DGLFunctionHandle)(
static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
*func =
(DGLFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(
func_name));
API_END();
}
void* DGLBackendAllocWorkspace(int device_type,
int device_id,
uint64_t size,
int dtype_code_hint,
int dtype_bits_hint) {
void* DGLBackendAllocWorkspace(
int device_type, int device_id, uint64_t size, int dtype_code_hint,
int dtype_bits_hint) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id;
......@@ -213,14 +208,11 @@ void* DGLBackendAllocWorkspace(int device_type,
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1;
return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
static_cast<size_t>(size),
type_hint);
return DeviceAPIManager::Get(ctx)->AllocWorkspace(
ctx, static_cast<size_t>(size), type_hint);
}
int DGLBackendFreeWorkspace(int device_type,
int device_id,
void* ptr) {
int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id;
......@@ -228,10 +220,7 @@ int DGLBackendFreeWorkspace(int device_type,
return 0;
}
int DGLBackendRunOnce(void** handle,
int (*f)(void*),
void* cdata,
int nbytes) {
int DGLBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
if (*handle == nullptr) {
*handle = reinterpret_cast<void*>(1);
return (*f)(cdata);
......@@ -245,19 +234,15 @@ int DGLFuncFree(DGLFunctionHandle func) {
API_END();
}
int DGLFuncCall(DGLFunctionHandle func,
DGLValue* args,
int* arg_type_codes,
int num_args,
DGLValue* ret_val,
int* ret_type_code) {
int DGLFuncCall(
DGLFunctionHandle func, DGLValue* args, int* arg_type_codes, int num_args,
DGLValue* ret_val, int* ret_type_code) {
API_BEGIN();
DGLRetValue rv;
(*static_cast<const PackedFunc*>(func)).CallPacked(
DGLArgs(args, arg_type_codes, num_args), &rv);
(*static_cast<const PackedFunc*>(func))
.CallPacked(DGLArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
if (rv.type_code() == kStr ||
rv.type_code() == kDGLDataType ||
if (rv.type_code() == kStr || rv.type_code() == kDGLDataType ||
rv.type_code() == kBytes) {
DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
if (rv.type_code() != kDGLDataType) {
......@@ -280,10 +265,8 @@ int DGLFuncCall(DGLFunctionHandle func,
API_END();
}
int DGLCFuncSetReturn(DGLRetValueHandle ret,
DGLValue* value,
int* type_code,
int num_ret) {
int DGLCFuncSetReturn(
DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret) {
API_BEGIN();
CHECK_EQ(num_ret, 1);
DGLRetValue* rv = static_cast<DGLRetValue*>(ret);
......@@ -291,16 +274,16 @@ int DGLCFuncSetReturn(DGLRetValueHandle ret,
API_END();
}
int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
void* resource_handle,
DGLPackedCFuncFinalizer fin,
DGLFunctionHandle *out) {
int DGLFuncCreateFromCFunc(
DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
DGLFunctionHandle* out) {
API_BEGIN();
if (fin == nullptr) {
*out = new PackedFunc(
[func, resource_handle](DGLArgs args, DGLRetValue* rv) {
int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
*out =
new PackedFunc([func, resource_handle](DGLArgs args, DGLRetValue* rv) {
int ret = func(
(DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
if (ret != 0) {
std::string err = "DGLCall CFunc Error:\n";
err += DGLGetLastError();
......@@ -311,16 +294,16 @@ int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
// wrap it in a shared_ptr, with fin as deleter.
// so fin will be called when the lambda went out of scope.
std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc(
[func, rpack](DGLArgs args, DGLRetValue* rv) {
int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get());
if (ret != 0) {
std::string err = "DGLCall CFunc Error:\n";
err += DGLGetLastError();
throw dmlc::Error(err);
}
});
*out = new PackedFunc([func, rpack](DGLArgs args, DGLRetValue* rv) {
int ret = func(
(DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get());
if (ret != 0) {
std::string err = "DGLCall CFunc Error:\n";
err += DGLGetLastError();
throw dmlc::Error(err);
}
});
}
API_END();
}
......@@ -370,10 +353,8 @@ int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
API_END();
}
int DGLStreamStreamSynchronize(int device_type,
int device_id,
DGLStreamHandle src,
DGLStreamHandle dst) {
int DGLStreamStreamSynchronize(
int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst) {
API_BEGIN();
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type);
......@@ -392,36 +373,35 @@ int DGLCbArgToReturn(DGLValue* value, int code) {
API_END();
}
int DGLLoadTensorAdapter(const char *path) {
int DGLLoadTensorAdapter(const char* path) {
return TensorDispatcher::Global()->Load(path) ? 0 : -1;
}
// set device api
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
.set_body([](DGLArgs args, DGLRetValue *ret) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAPIManager::Get(ctx)->SetDevice(ctx);
});
.set_body([](DGLArgs args, DGLRetValue* ret) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAPIManager::Get(ctx)->SetDevice(ctx);
});
// set device api
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](DGLArgs args, DGLRetValue *ret) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
if (kind == kExist) {
DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
if (api != nullptr) {
api->GetAttr(ctx, kind, ret);
.set_body([](DGLArgs args, DGLRetValue* ret) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
if (kind == kExist) {
DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
if (api != nullptr) {
api->GetAttr(ctx, kind, ret);
} else {
*ret = 0;
}
} else {
*ret = 0;
DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
}
} else {
DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
}
});
});
......@@ -4,32 +4,28 @@
* \brief DGL runtime config
*/
#include <dgl/runtime/registry.h>
#include <dgl/runtime/config.h>
#include <dgl/runtime/registry.h>
using namespace dgl::runtime;
namespace dgl {
namespace runtime {
void Config::EnableLibxsmm(bool b) {
libxsmm_ = b;
}
void Config::EnableLibxsmm(bool b) { libxsmm_ = b; }
bool Config::IsLibxsmmAvailable() const {
return libxsmm_;
}
bool Config::IsLibxsmmAvailable() const { return libxsmm_; }
DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigSetLibxsmm")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
bool use_libxsmm = args[0];
dgl::runtime::Config::Global()->EnableLibxsmm(use_libxsmm);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
bool use_libxsmm = args[0];
dgl::runtime::Config::Global()->EnableLibxsmm(use_libxsmm);
});
DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigGetLibxsmm")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = dgl::runtime::Config::Global()->IsLibxsmmAvailable();
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = dgl::runtime::Config::Global()->IsLibxsmmAvailable();
});
} // namespace runtime
} // namespace dgl
......@@ -2,13 +2,15 @@
* Copyright (c) 2016-2022 by Contributors
* \file cpu_device_api.cc
*/
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <cstdlib>
#include <cstring>
#include "workspace_pool.h"
namespace dgl {
......@@ -21,13 +23,11 @@ class CPUDeviceAPI final : public DeviceAPI {
*rv = 1;
}
}
void* AllocDataSpace(DGLContext ctx,
size_t nbytes,
size_t alignment,
DGLDataType type_hint) final {
void* AllocDataSpace(
DGLContext ctx, size_t nbytes, size_t alignment,
DGLDataType type_hint) final {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CPUAllocWorkspace(nbytes);
if (td->IsAvailable()) return td->CPUAllocWorkspace(nbytes);
void* ptr;
#if _MSC_VER || defined(__MINGW32__)
......@@ -45,8 +45,7 @@ class CPUDeviceAPI final : public DeviceAPI {
void FreeDataSpace(DGLContext ctx, void* ptr) final {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CPUFreeWorkspace(ptr);
if (td->IsAvailable()) return td->CPUFreeWorkspace(ptr);
#if _MSC_VER || defined(__MINGW32__)
_aligned_free(ptr);
......@@ -55,25 +54,21 @@ class CPUDeviceAPI final : public DeviceAPI {
#endif
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint) final {
memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
void CopyDataFromTo(
const void* from, size_t from_offset, void* to, size_t to_offset,
size_t size, DGLContext ctx_from, DGLContext ctx_to,
DGLDataType type_hint) final {
memcpy(
static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset, size);
}
DGLStreamHandle CreateStream(DGLContext) final { return nullptr; }
void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
}
void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {}
void* AllocWorkspace(DGLContext ctx, size_t size, DGLDataType type_hint) final;
void* AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) final;
void FreeWorkspace(DGLContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() {
......@@ -84,32 +79,29 @@ class CPUDeviceAPI final : public DeviceAPI {
};
struct CPUWorkspacePool : public WorkspacePool {
CPUWorkspacePool() :
WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
CPUWorkspacePool() : WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
};
void* CPUDeviceAPI::AllocWorkspace(DGLContext ctx,
size_t size,
DGLDataType type_hint) {
void* CPUDeviceAPI::AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CPUAllocWorkspace(size);
if (td->IsAvailable()) return td->CPUAllocWorkspace(size);
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(ctx, size);
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(
ctx, size);
}
void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CPUFreeWorkspace(data);
if (td->IsAvailable()) return td->CPUFreeWorkspace(data);
dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
}
DGL_REGISTER_GLOBAL("device_api.cpu")
.set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CPUDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CPUDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
} // namespace dgl
......@@ -7,11 +7,13 @@
#define DGL_RUNTIME_CUDA_CUDA_COMMON_H_
#include <cublas_v2.h>
#include <cusparse.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <cusparse.h>
#include <dgl/runtime/packed_func.h>
#include <string>
#include "../workspace_pool.h"
namespace dgl {
......@@ -19,94 +21,89 @@ namespace runtime {
template <typename T>
inline bool is_zero(T size) {
return size == 0;
return size == 0;
}
template <>
inline bool is_zero<dim3>(dim3 size) {
return size.x == 0 || size.y == 0 || size.z == 0;
return size.x == 0 || size.y == 0 || size.z == 0;
}
#define CUDA_DRIVER_CALL(x) \
{ \
CUresult result = x; \
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \
const char *msg; \
const char* msg; \
cuGetErrorName(result, &msg); \
LOG(FATAL) \
<< "CUDAError: " #x " failed with error: " << msg; \
LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg; \
} \
}
#define CUDA_CALL(func) \
{ \
cudaError_t e = (func); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA: " << cudaGetErrorString(e); \
#define CUDA_CALL(func) \
{ \
cudaError_t e = (func); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA: " << cudaGetErrorString(e); \
}
#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...) \
{ \
if (!dgl::runtime::is_zero((nblks)) && \
!dgl::runtime::is_zero((nthrs))) { \
(kernel) <<< (nblks), (nthrs), (shmem), (stream) >>> \
(__VA_ARGS__); \
cudaError_t e = cudaGetLastError(); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA kernel launch error: " \
<< cudaGetErrorString(e); \
} \
#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...) \
{ \
if (!dgl::runtime::is_zero((nblks)) && !dgl::runtime::is_zero((nthrs))) { \
(kernel)<<<(nblks), (nthrs), (shmem), (stream)>>>(__VA_ARGS__); \
cudaError_t e = cudaGetLastError(); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA kernel launch error: " << cudaGetErrorString(e); \
} \
}
#define CUSPARSE_CALL(func) \
{ \
cusparseStatus_t e = (func); \
CHECK(e == CUSPARSE_STATUS_SUCCESS) \
<< "CUSPARSE ERROR: " << e; \
#define CUSPARSE_CALL(func) \
{ \
cusparseStatus_t e = (func); \
CHECK(e == CUSPARSE_STATUS_SUCCESS) << "CUSPARSE ERROR: " << e; \
}
#define CUBLAS_CALL(func) \
{ \
cublasStatus_t e = (func); \
CHECK(e == CUBLAS_STATUS_SUCCESS) << "CUBLAS ERROR: " << e; \
#define CUBLAS_CALL(func) \
{ \
cublasStatus_t e = (func); \
CHECK(e == CUBLAS_STATUS_SUCCESS) << "CUBLAS ERROR: " << e; \
}
#define CURAND_CALL(func) \
{ \
curandStatus_t e = (func); \
CHECK(e == CURAND_STATUS_SUCCESS) \
<< "CURAND Error: " << dgl::runtime::curandGetErrorString(e) \
<< " at " << __FILE__ << ":" << __LINE__; \
}
#define CURAND_CALL(func) \
{ \
curandStatus_t e = (func); \
CHECK(e == CURAND_STATUS_SUCCESS) \
<< "CURAND Error: " << dgl::runtime::curandGetErrorString(e) << " at " \
<< __FILE__ << ":" << __LINE__; \
}
inline const char* curandGetErrorString(curandStatus_t error) {
switch (error) {
case CURAND_STATUS_SUCCESS:
return "CURAND_STATUS_SUCCESS";
case CURAND_STATUS_VERSION_MISMATCH:
return "CURAND_STATUS_VERSION_MISMATCH";
case CURAND_STATUS_NOT_INITIALIZED:
return "CURAND_STATUS_NOT_INITIALIZED";
case CURAND_STATUS_ALLOCATION_FAILED:
return "CURAND_STATUS_ALLOCATION_FAILED";
case CURAND_STATUS_TYPE_ERROR:
return "CURAND_STATUS_TYPE_ERROR";
case CURAND_STATUS_OUT_OF_RANGE:
return "CURAND_STATUS_OUT_OF_RANGE";
case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
case CURAND_STATUS_LAUNCH_FAILURE:
return "CURAND_STATUS_LAUNCH_FAILURE";
case CURAND_STATUS_PREEXISTING_FAILURE:
return "CURAND_STATUS_PREEXISTING_FAILURE";
case CURAND_STATUS_INITIALIZATION_FAILED:
return "CURAND_STATUS_INITIALIZATION_FAILED";
case CURAND_STATUS_ARCH_MISMATCH:
return "CURAND_STATUS_ARCH_MISMATCH";
case CURAND_STATUS_INTERNAL_ERROR:
return "CURAND_STATUS_INTERNAL_ERROR";
case CURAND_STATUS_SUCCESS:
return "CURAND_STATUS_SUCCESS";
case CURAND_STATUS_VERSION_MISMATCH:
return "CURAND_STATUS_VERSION_MISMATCH";
case CURAND_STATUS_NOT_INITIALIZED:
return "CURAND_STATUS_NOT_INITIALIZED";
case CURAND_STATUS_ALLOCATION_FAILED:
return "CURAND_STATUS_ALLOCATION_FAILED";
case CURAND_STATUS_TYPE_ERROR:
return "CURAND_STATUS_TYPE_ERROR";
case CURAND_STATUS_OUT_OF_RANGE:
return "CURAND_STATUS_OUT_OF_RANGE";
case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
case CURAND_STATUS_LAUNCH_FAILURE:
return "CURAND_STATUS_LAUNCH_FAILURE";
case CURAND_STATUS_PREEXISTING_FAILURE:
return "CURAND_STATUS_PREEXISTING_FAILURE";
case CURAND_STATUS_INITIALIZATION_FAILED:
return "CURAND_STATUS_INITIALIZATION_FAILED";
case CURAND_STATUS_ARCH_MISMATCH:
return "CURAND_STATUS_ARCH_MISMATCH";
case CURAND_STATUS_INTERNAL_ERROR:
return "CURAND_STATUS_INTERNAL_ERROR";
}
// To suppress compiler warning.
return "Unrecognized curand error string";
......
......@@ -3,11 +3,12 @@
* \file cuda_device_api.cc
* \brief GPU specific API
*/
#include <cuda_runtime.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <cuda_runtime.h>
#include "cuda_common.h"
namespace dgl {
......@@ -28,9 +29,7 @@ class CUDADeviceAPI final : public DeviceAPI {
is_available_ = count > 0;
}
bool IsAvailable() final {
return is_available_;
}
bool IsAvailable() final { return is_available_; }
void SetDevice(DGLContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
......@@ -39,10 +38,10 @@ class CUDADeviceAPI final : public DeviceAPI {
int value = 0;
switch (kind) {
case kExist:
value = (
cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
== cudaSuccess);
value =
(cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) ==
cudaSuccess);
break;
case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute(
......@@ -50,8 +49,8 @@ class CUDADeviceAPI final : public DeviceAPI {
break;
}
case kWarpSize: {
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrWarpSize, ctx.device_id));
CUDA_CALL(
cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, ctx.device_id));
break;
}
case kMaxSharedMemoryPerBlock: {
......@@ -96,26 +95,24 @@ class CUDADeviceAPI final : public DeviceAPI {
&dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));
std::stringstream ss; // use json string to return multiple int values;
ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
return;
}
}
*rv = value;
}
void* AllocDataSpace(DGLContext ctx,
size_t nbytes,
size_t alignment,
DGLDataType type_hint) final {
void* AllocDataSpace(
DGLContext ctx, size_t nbytes, size_t alignment,
DGLDataType type_hint) final {
SetDevice(ctx);
// Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAAllocWorkspace(nbytes, getCurrentCUDAStream());
CHECK_EQ(256 % alignment, 0U)
<< "CUDA space is aligned at 256 bytes";
void *ret;
CHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes";
void* ret;
CUDA_CALL(cudaMalloc(&ret, nbytes));
return ret;
}
......@@ -123,21 +120,15 @@ class CUDADeviceAPI final : public DeviceAPI {
void FreeDataSpace(DGLContext ctx, void* ptr) final {
SetDevice(ctx);
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAFreeWorkspace(ptr);
if (td->IsAvailable()) return td->CUDAFreeWorkspace(ptr);
CUDA_CALL(cudaFree(ptr));
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint,
DGLStreamHandle stream) {
void CopyDataFromTo(
const void* from, size_t from_offset, void* to, size_t to_offset,
size_t size, DGLContext ctx_from, DGLContext ctx_to,
DGLDataType type_hint, DGLStreamHandle stream) {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset;
......@@ -146,14 +137,15 @@ class CUDADeviceAPI final : public DeviceAPI {
if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
} else {
CUDA_CALL(cudaMemcpyPeerAsync(to, ctx_to.device_id,
from, ctx_from.device_id,
size, cu_stream));
CUDA_CALL(cudaMemcpyPeerAsync(
to, ctx_to.device_id, from, ctx_from.device_id, size, cu_stream));
}
} else if (ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCPU) {
} else if (
ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
} else if (ctx_from.device_type == kDGLCPU && ctx_to.device_type == kDGLCUDA) {
} else if (
ctx_from.device_type == kDGLCPU && ctx_to.device_type == kDGLCUDA) {
CUDA_CALL(cudaSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
} else {
......@@ -161,16 +153,14 @@ class CUDADeviceAPI final : public DeviceAPI {
}
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint) final {
void CopyDataFromTo(
const void* from, size_t from_offset, void* to, size_t to_offset,
size_t size, DGLContext ctx_from, DGLContext ctx_to,
DGLDataType type_hint) final {
auto stream = GetStream();
CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint, stream);
CopyDataFromTo(
from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint,
stream);
}
DGLStreamHandle CreateStream(DGLContext ctx) {
......@@ -187,7 +177,8 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDA_CALL(cudaStreamDestroy(cu_stream));
}
void SyncStreamFromTo(DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
void SyncStreamFromTo(
DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
CUDA_CALL(cudaSetDevice(ctx.device_id));
cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
......@@ -222,54 +213,54 @@ class CUDADeviceAPI final : public DeviceAPI {
*/
void PinData(void* ptr, size_t nbytes) {
// prevent users from pinning empty tensors or graphs
if (ptr == nullptr || nbytes == 0)
return;
if (ptr == nullptr || nbytes == 0) return;
CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
}
void UnpinData(void* ptr) {
if (ptr == nullptr)
return;
if (ptr == nullptr) return;
CUDA_CALL(cudaHostUnregister(ptr));
}
bool IsPinned(const void* ptr) override {
// can't be a pinned tensor if CUDA context is unavailable.
if (!is_available_)
return false;
if (!is_available_) return false;
cudaPointerAttributes attr;
cudaError_t status = cudaPointerGetAttributes(&attr, ptr);
bool result = false;
switch (status) {
case cudaErrorInvalidValue:
// might be a normal CPU tensor in CUDA 10.2-
cudaGetLastError(); // clear error
break;
case cudaSuccess:
result = (attr.type == cudaMemoryTypeHost);
break;
case cudaErrorInitializationError:
case cudaErrorNoDevice:
case cudaErrorInsufficientDriver:
case cudaErrorInvalidDevice:
// We don't want to fail in these particular cases since this function can be called
// when users only want to run on CPU even if CUDA API is enabled, or in a forked
// subprocess where CUDA context cannot be initialized. So we just mark the CUDA
// context to unavailable and return.
is_available_ = false;
cudaGetLastError(); // clear error
break;
default:
LOG(FATAL) << "error while determining memory status: " << cudaGetErrorString(status);
break;
case cudaErrorInvalidValue:
// might be a normal CPU tensor in CUDA 10.2-
cudaGetLastError(); // clear error
break;
case cudaSuccess:
result = (attr.type == cudaMemoryTypeHost);
break;
case cudaErrorInitializationError:
case cudaErrorNoDevice:
case cudaErrorInsufficientDriver:
case cudaErrorInvalidDevice:
// We don't want to fail in these particular cases since this function
// can be called when users only want to run on CPU even if CUDA API is
// enabled, or in a forked subprocess where CUDA context cannot be
// initialized. So we just mark the CUDA context as unavailable and
// return.
is_available_ = false;
cudaGetLastError(); // clear error
break;
default:
LOG(FATAL) << "error while determining memory status: "
<< cudaGetErrorString(status);
break;
}
return result;
}
void* AllocWorkspace(DGLContext ctx, size_t size, DGLDataType type_hint) final {
void* AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) final {
SetDevice(ctx);
// Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global();
......@@ -282,8 +273,7 @@ class CUDADeviceAPI final : public DeviceAPI {
void FreeWorkspace(DGLContext ctx, void* data) final {
SetDevice(ctx);
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAFreeWorkspace(data);
if (td->IsAvailable()) return td->CUDAFreeWorkspace(data);
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
......@@ -295,14 +285,13 @@ class CUDADeviceAPI final : public DeviceAPI {
}
private:
static void GPUCopy(const void* from,
void* to,
size_t size,
cudaMemcpyKind kind,
cudaStream_t stream) {
static void GPUCopy(
const void* from, void* to, size_t size, cudaMemcpyKind kind,
cudaStream_t stream) {
CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
if (stream == 0 && kind == cudaMemcpyDeviceToHost) {
// only wait for the copy, when it's on the default stream, and it's to host memory
// only wait for the copy, when it's on the default stream, and it's to
// host memory
CUDA_CALL(cudaStreamSynchronize(stream));
}
}
......@@ -312,9 +301,7 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
CUDAThreadEntry::CUDAThreadEntry()
: pool(kDGLCUDA, CUDADeviceAPI::Global()) {
}
CUDAThreadEntry::CUDAThreadEntry() : pool(kDGLCUDA, CUDADeviceAPI::Global()) {}
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
return CUDAThreadStore::Get();
......@@ -329,10 +316,10 @@ cudaStream_t getCurrentCUDAStream() {
}
DGL_REGISTER_GLOBAL("device_api.cuda")
.set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
} // namespace dgl
......@@ -9,14 +9,14 @@
#include <dgl/runtime/c_runtime_api.h>
#include "cuda_runtime.h"
#include "cuda_common.h"
#include "cuda_runtime.h"
namespace dgl {
namespace runtime {
namespace cuda {
template<typename>
template <typename>
class OrderedHashTable;
/*!
......@@ -31,7 +31,7 @@ class OrderedHashTable;
* used.
*
* The hash table should be used in two phases, with the first being populating
* the hash table with the OrderedHashTable object, and then generating this
* the hash table with the OrderedHashTable object, and then generating this
* handle from it. This object can then be used to search the hash table,
* to find mappings, from with CUDA code.
*
......@@ -62,7 +62,7 @@ class OrderedHashTable;
*
* \tparam IdType The type of the IDs.
*/
template<typename IdType>
template <typename IdType>
class DeviceOrderedHashTable {
public:
/**
......@@ -80,16 +80,15 @@ class DeviceOrderedHashTable {
/**
* \brief The index of the item when inserted into the hashtable (e.g.,
* the index within the array passed into FillWithDuplicates()).
*/
*/
int64_t index;
};
typedef const Mapping* ConstIterator;
DeviceOrderedHashTable(
const DeviceOrderedHashTable& other) = default;
DeviceOrderedHashTable& operator=(
const DeviceOrderedHashTable& other) = default;
DeviceOrderedHashTable(const DeviceOrderedHashTable& other) = default;
DeviceOrderedHashTable& operator=(const DeviceOrderedHashTable& other) =
default;
/**
* \brief Find the non-mutable mapping of a given key within the hash table.
......@@ -101,8 +100,7 @@ class DeviceOrderedHashTable {
*
* \return An iterator to the mapping.
*/
inline __device__ ConstIterator Search(
const IdType id) const {
inline __device__ ConstIterator Search(const IdType id) const {
const IdType pos = SearchForPosition(id);
return &table_[pos];
......@@ -115,8 +113,7 @@ class DeviceOrderedHashTable {
*
* \return True if the key exists in the hashtable.
*/
inline __device__ bool Contains(
const IdType id) const {
inline __device__ bool Contains(const IdType id) const {
IdType pos = Hash(id);
IdType delta = 1;
......@@ -124,8 +121,8 @@ class DeviceOrderedHashTable {
if (table_[pos].key == id) {
return true;
}
pos = Hash(pos+delta);
delta +=1;
pos = Hash(pos + delta);
delta += 1;
}
return false;
}
......@@ -134,7 +131,7 @@ class DeviceOrderedHashTable {
// Must be uniform bytes for memset to work
static constexpr IdType kEmptyKey = static_cast<IdType>(-1);
const Mapping * table_;
const Mapping* table_;
size_t size_;
/**
......@@ -143,9 +140,7 @@ class DeviceOrderedHashTable {
* \param table The table stored in GPU memory.
* \param size The size of the table.
*/
explicit DeviceOrderedHashTable(
const Mapping * table,
size_t size);
explicit DeviceOrderedHashTable(const Mapping* table, size_t size);
/**
* \brief Search for an item in the hash table which is known to exist.
......@@ -157,16 +152,15 @@ class DeviceOrderedHashTable {
*
* \return The position of the item in the hashtable.
*/
inline __device__ IdType SearchForPosition(
const IdType id) const {
inline __device__ IdType SearchForPosition(const IdType id) const {
IdType pos = Hash(id);
// linearly scan for matching entry
IdType delta = 1;
while (table_[pos].key != id) {
assert(table_[pos].key != kEmptyKey);
pos = Hash(pos+delta);
delta +=1;
pos = Hash(pos + delta);
delta += 1;
}
assert(pos < size_);
......@@ -180,10 +174,7 @@ class DeviceOrderedHashTable {
*
* \return The hash.
*/
inline __device__ size_t Hash(
const IdType id) const {
return id % size_;
}
inline __device__ size_t Hash(const IdType id) const { return id % size_; }
friend class OrderedHashTable<IdType>;
};
......@@ -219,7 +210,7 @@ class DeviceOrderedHashTable {
*
* \tparam IdType The type of the IDs.
*/
template<typename IdType>
template <typename IdType>
class OrderedHashTable {
public:
static constexpr int kDefaultScale = 3;
......@@ -237,9 +228,7 @@ class OrderedHashTable {
* \param stream The stream to use for initializing the hashtable.
*/
OrderedHashTable(
const size_t size,
DGLContext ctx,
cudaStream_t stream,
const size_t size, DGLContext ctx, cudaStream_t stream,
const int scale = kDefaultScale);
/**
......@@ -248,10 +237,8 @@ class OrderedHashTable {
~OrderedHashTable();
// Disable copying
OrderedHashTable(
const OrderedHashTable& other) = delete;
OrderedHashTable& operator=(
const OrderedHashTable& other) = delete;
OrderedHashTable(const OrderedHashTable& other) = delete;
OrderedHashTable& operator=(const OrderedHashTable& other) = delete;
/**
* \brief Fill the hashtable with the array containing possibly duplicate
......@@ -264,11 +251,8 @@ class OrderedHashTable {
* \param stream The stream to perform operations on.
*/
void FillWithDuplicates(
const IdType * const input,
const size_t num_input,
IdType * const unique,
int64_t * const num_unique,
cudaStream_t stream);
const IdType* const input, const size_t num_input, IdType* const unique,
int64_t* const num_unique, cudaStream_t stream);
/**
* \brief Fill the hashtable with an array of unique keys.
......@@ -278,9 +262,7 @@ class OrderedHashTable {
* \param stream The stream to perform operations on.
*/
void FillWithUnique(
const IdType * const input,
const size_t num_input,
cudaStream_t stream);
const IdType* const input, const size_t num_input, cudaStream_t stream);
/**
* \brief Get a version of the hashtable usable from device functions.
......@@ -290,12 +272,11 @@ class OrderedHashTable {
DeviceOrderedHashTable<IdType> DeviceHandle() const;
private:
Mapping * table_;
Mapping* table_;
size_t size_;
DGLContext ctx_;
};
} // namespace cuda
} // namespace runtime
} // namespace dgl
......
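The headers above describe a two-phase flow: populate the table on a stream via OrderedHashTable, then hand a DeviceOrderedHashTable view to CUDA kernels for read-only lookups. Below is a rough sketch of that flow using only the declarations visible in this diff; the kernel name, helper name, and device-pointer arguments are hypothetical, allocation and error handling are omitted, and the relevant DGL headers are assumed to be included.
// Hypothetical two-phase usage of the hash table declared above.
#include <cuda_runtime.h>
#include <cstdint>
using dgl::runtime::cuda::DeviceOrderedHashTable;
using dgl::runtime::cuda::OrderedHashTable;
template <typename IdType>
__global__ void LookupKernel(
    DeviceOrderedHashTable<IdType> table, const IdType* ids, int64_t* out,
    int num) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num) {
    // Search() assumes the key was inserted during the fill phase.
    out[i] = table.Search(ids[i])->index;
  }
}
void BuildAndLookup(
    const int32_t* d_keys, size_t num_keys, int32_t* d_unique,
    int64_t* num_unique, const int32_t* d_queries, int64_t* d_out,
    int num_queries, DGLContext ctx, cudaStream_t stream) {
  // Phase 1: populate the table with possibly duplicated keys.
  OrderedHashTable<int32_t> table(num_keys, ctx, stream);
  table.FillWithDuplicates(d_keys, num_keys, d_unique, num_unique, stream);
  // Phase 2: pass a device-side view of the table into a kernel for searches.
  const int nt = 256;
  LookupKernel<<<(num_queries + nt - 1) / nt, nt, 0, stream>>>(
      table.DeviceHandle(), d_queries, d_out, num_queries);
}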
......@@ -14,10 +14,9 @@
* limitations under the License.
*
* \file nccl_api.h
* \brief Wrapper around NCCL routines.
* \brief Wrapper around NCCL routines.
*/
#ifndef DGL_RUNTIME_CUDA_NCCL_API_H_
#define DGL_RUNTIME_CUDA_NCCL_API_H_
......@@ -27,11 +26,14 @@
// if not compiling with NCCL, this class will only support communicators of
// size 1.
#define NCCL_UNIQUE_ID_BYTES 128
typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId;
typedef struct {
char internal[NCCL_UNIQUE_ID_BYTES];
} ncclUniqueId;
typedef int ncclComm_t;
#endif
#include <dgl/runtime/object.h>
#include <string>
namespace dgl {
......@@ -59,17 +61,13 @@ DGL_DEFINE_OBJECT_REF(NCCLUniqueIdRef, NCCLUniqueId);
class NCCLCommunicator : public runtime::Object {
public:
NCCLCommunicator(
int size,
int rank,
ncclUniqueId id);
NCCLCommunicator(int size, int rank, ncclUniqueId id);
~NCCLCommunicator();
// disable copying
NCCLCommunicator(const NCCLCommunicator& other) = delete;
NCCLCommunicator& operator=(
const NCCLCommunicator& other);
NCCLCommunicator& operator=(const NCCLCommunicator& other);
ncclComm_t Get();
......@@ -81,12 +79,9 @@ class NCCLCommunicator : public runtime::Object {
* @param count The size of data to send to each rank.
* @param stream The stream to operate on.
*/
template<typename IdType>
template <typename IdType>
void AllToAll(
const IdType * send,
IdType * recv,
int64_t count,
cudaStream_t stream);
const IdType* send, IdType* recv, int64_t count, cudaStream_t stream);
/**
* @brief Perform an all-to-all variable sized communication.
......@@ -99,13 +94,10 @@ class NCCLCommunicator : public runtime::Object {
* @param type The type of data to send.
* @param stream The stream to operate on.
*/
template<typename DType>
template <typename DType>
void AllToAllV(
const DType * const send,
const int64_t * send_prefix,
DType * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
const DType* const send, const int64_t* send_prefix, DType* const recv,
const int64_t* recv_prefix, cudaStream_t stream);
/**
* @brief Perform an all-to-all with sparse data (idx and value pairs). By
......@@ -124,16 +116,11 @@ class NCCLCommunicator : public runtime::Object {
* receive on the host.
* @param stream The stream to communicate on.
*/
template<typename IdType, typename DType>
template <typename IdType, typename DType>
void SparseAllToAll(
const IdType * send_idx,
const DType * send_value,
const int64_t num_feat,
const int64_t * send_prefix,
IdType * recv_idx,
DType * recv_value,
const int64_t * recv_prefix,
cudaStream_t stream);
const IdType* send_idx, const DType* send_value, const int64_t num_feat,
const int64_t* send_prefix, IdType* recv_idx, DType* recv_value,
const int64_t* recv_prefix, cudaStream_t stream);
int size() const;
......
......@@ -3,12 +3,12 @@
* \file src/runtime/dlpack_convert.cc
* \brief Conversion between NDArray and DLPack.
*/
#include <dgl/runtime/dlpack_convert.h>
#include <dlpack/dlpack.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/dlpack_convert.h>
#include <dgl/runtime/ndarray.h>
#include <dlpack/dlpack.h>
#include "runtime_base.h"
// deleter for arrays used by DLPack exporter
......@@ -69,8 +69,7 @@ NDArray DLPackConvert::FromDLPack(DLManagedTensor* tensor) {
void DLPackConvert::DLPackDeleter(NDArray::Container* ptr) {
// if the array is pinned by dgl, unpin it before freeing
if (ptr->pinned_by_dgl_)
NDArray::UnpinContainer(ptr);
if (ptr->pinned_by_dgl_) NDArray::UnpinContainer(ptr);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);
if (tensor->deleter != nullptr) {
(*tensor->deleter)(tensor);
......@@ -95,7 +94,7 @@ DLManagedTensor* ContainerToDLPack(NDArray::Container* from) {
return ret;
}
DLManagedTensor* DLPackConvert::ToDLPack(const NDArray &from) {
DLManagedTensor* DLPackConvert::ToDLPack(const NDArray& from) {
return ContainerToDLPack(from.data_);
}
......@@ -113,15 +112,14 @@ inline bool IsAligned(const void* ptr, std::uintptr_t alignment) noexcept {
return !(iptr % alignment);
}
int DGLArrayFromDLPack(DLManagedTensor* from,
DGLArrayHandle* out) {
int DGLArrayFromDLPack(DLManagedTensor* from, DGLArrayHandle* out) {
API_BEGIN();
*out = NDArray::Internal::MoveAsDGLArray(DLPackConvert::FromDLPack(from));
API_END();
}
int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment) {
int DGLArrayToDLPack(
DGLArrayHandle from, DLManagedTensor** out, int alignment) {
API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(from);
DGLArray* nd = &(nd_container->dl_tensor);
......
......@@ -4,8 +4,9 @@
* \brief Module to load from dynamic shared library.
*/
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include "module_util.h"
#if defined(_WIN32)
......@@ -25,9 +26,7 @@ class DSOModuleNode final : public ModuleNode {
if (lib_handle_) Unload();
}
const char* type_key() const final {
return "dso";
}
const char* type_key() const final { return "dso"; }
PackedFunc GetFunction(
const std::string& name,
......@@ -36,8 +35,9 @@ class DSOModuleNode final : public ModuleNode {
if (name == runtime::symbol::dgl_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::dgl_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::dgl_module_main << " is not presented";
CHECK(entry_name != nullptr)
<< "Symbol " << runtime::symbol::dgl_module_main
<< " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
......@@ -48,17 +48,15 @@ class DSOModuleNode final : public ModuleNode {
void Init(const std::string& name) {
Load(name);
if (auto *ctx_addr =
reinterpret_cast<void**>(GetSymbol(runtime::symbol::dgl_module_ctx))) {
if (auto* ctx_addr = reinterpret_cast<void**>(
GetSymbol(runtime::symbol::dgl_module_ctx))) {
*ctx_addr = this;
}
InitContextFunctions([this](const char* fname) {
return GetSymbol(fname);
});
InitContextFunctions(
[this](const char* fname) { return GetSymbol(fname); });
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::dgl_dev_mblob));
const char* dev_mblob = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::dgl_dev_mblob));
if (dev_mblob != nullptr) {
ImportModuleBlob(dev_mblob, &imports_);
}
......@@ -79,11 +77,9 @@ class DSOModuleNode final : public ModuleNode {
}
void* GetSymbol(const char* name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() { FreeLibrary(lib_handle_); }
#else
// Library handle
void* lib_handle_{nullptr};
......@@ -91,23 +87,18 @@ class DSOModuleNode final : public ModuleNode {
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name
<< " " << dlerror();
}
void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name);
}
void Unload() {
dlclose(lib_handle_);
<< "Failed to load dynamic shared library " << name << " " << dlerror();
}
void* GetSymbol(const char* name) { return dlsym(lib_handle_, name); }
void Unload() { dlclose(lib_handle_); }
#endif
};
DGL_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]);
*rv = runtime::Module(n);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]);
*rv = runtime::Module(n);
});
} // namespace runtime
} // namespace dgl