Unverified Commit 401e1278 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4811)



* [Misc] clang-format auto fix.

* fix

* manual
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 6c53f351
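The hunks below are clang-format style fixes. As a quick reference, here is a hedged C++ illustration of the conventions the tool appears to enforce in this commit (left pointer alignment, roughly 80-column wrapping of long parameter lists, short bodies collapsed onto one line); the exact .clang-format options are an assumption inferred from the hunks, not part of the commit itself.

#include <cstddef>
#include <cstdint>
#include <string>

// Hypothetical class, for illustration only (not project code).
class StyleExample {
 public:
  // Pointer binds to the type: "const char* ip", not "const char * ip".
  bool Connect(const char* ip, int port);
  int64_t Send(const char* data, int64_t len_data);

  // Trivial accessors collapse onto a single line.
  const std::string& NetType() const { return net_type_; }

  // Long signatures break after the opening parenthesis and wrap at ~80 columns.
  static void CopyDataFromTo(
      const void* from, size_t from_offset, void* to, size_t to_offset,
      size_t size);

 private:
  std::string net_type_ = "tensorpipe";
};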
......@@ -41,7 +41,7 @@ class TCPSocket {
* \param port end port
* \return true for success and false for failure
*/
bool Connect(const char * ip, int port);
bool Connect(const char* ip, int port);
/*!
* \brief Bind on the given IP and PORT
......@@ -49,7 +49,7 @@ class TCPSocket {
* \param port end port
* \return true for success and false for failure
*/
bool Bind(const char * ip, int port);
bool Bind(const char* ip, int port);
/*!
* \brief listen for remote connection
......@@ -65,9 +65,7 @@ class TCPSocket {
* \param port_client new PORT will be stored to port_client
* \return true for success and false for failure
*/
bool Accept(TCPSocket * socket,
std::string * ip_client,
int * port_client);
bool Accept(TCPSocket* socket, std::string* ip_client, int* port_client);
/*!
* \brief SetNonBlocking() is needed, referring to this example of epoll:
......@@ -104,7 +102,7 @@ class TCPSocket {
* \param len_data length of data
* \return return number of bytes sent if OK, -1 on error
*/
int64_t Send(const char * data, int64_t len_data);
int64_t Send(const char* data, int64_t len_data);
/*!
* \brief Receive data.
......@@ -112,7 +110,7 @@ class TCPSocket {
* \param size_buffer size of buffer
* \return return number of bytes received if OK, -1 on error
*/
int64_t Receive(char * buffer, int64_t size_buffer);
int64_t Receive(char* buffer, int64_t size_buffer);
/*!
* \brief Get socket's file descriptor
......
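For context, the TCPSocket methods reformatted above form a small blocking-socket API. Below is a hedged client-side sketch that uses only the signatures visible in this hunk (Connect, Send, Receive); the header path, namespace, and default constructor are assumptions not shown in the excerpt.

#include <cstdint>
#include <iostream>

#include "tcp_socket.h"  // assumed header location for the class above

int main() {
  dgl::network::TCPSocket sock;  // namespace and default ctor are assumptions
  if (!sock.Connect("127.0.0.1", 50051)) {
    std::cerr << "connect failed\n";
    return 1;
  }
  const char msg[] = "ping";
  // Send/Receive return the number of bytes transferred, or -1 on error.
  if (sock.Send(msg, static_cast<int64_t>(sizeof(msg))) < 0) return 1;
  char buf[64];
  int64_t n = sock.Receive(buf, sizeof(buf));
  if (n > 0) std::cout << "received " << n << " bytes\n";
  return 0;
}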
......@@ -101,7 +101,8 @@ void InitGlobalTpContext() {
char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS");
if (numUvThreads_str) {
int numUvThreads = std::atoi(numUvThreads_str);
CHECK(numUvThreads > 0) << "DGL_SOCKET_NTHREADS should be positive integer if set";
CHECK(numUvThreads > 0)
<< "DGL_SOCKET_NTHREADS should be positive integer if set";
// Register multiplex uv channel
std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
......@@ -111,8 +112,8 @@ void InitGlobalTpContext() {
contexts.push_back(std::move(context));
listeners.push_back(contexts.back()->listen(address));
}
auto mptChannel = tensorpipe::channel::mpt::create(std::move(contexts),
std::move(listeners));
auto mptChannel = tensorpipe::channel::mpt::create(
std::move(contexts), std::move(listeners));
context->registerChannel(20 /* high priority */, "mpt", mptChannel);
}
}
......@@ -120,10 +121,10 @@ void InitGlobalTpContext() {
//////////////////////////// C APIs ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
.set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::Reset(); });
.set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::Reset(); });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t msg_queue_size = args[0];
std::string type = args[1];
int max_thread_count = args[2];
......@@ -138,11 +139,12 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
}
LOG(INFO) << "Sender with NetType~"
<< RPCContext::getInstance()->sender->NetType() << " is created.";
});
<< RPCContext::getInstance()->sender->NetType()
<< " is created.";
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t msg_queue_size = args[0];
std::string type = args[1];
int max_thread_count = args[2];
......@@ -157,171 +159,174 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
}
LOG(INFO) << "Receiver with NetType~"
<< RPCContext::getInstance()->receiver->NetType() << " is created.";
});
<< RPCContext::getInstance()->receiver->NetType()
<< " is created.";
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
RPCContext::getInstance()->sender->Finalize();
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
RPCContext::getInstance()->receiver->Finalize();
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::string ip = args[0];
int port = args[1];
int num_sender = args[2];
bool blocking = args[3];
std::string addr;
addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
if (RPCContext::getInstance()->receiver->Wait(addr, num_sender, blocking) == false) {
if (RPCContext::getInstance()->receiver->Wait(
addr, num_sender, blocking) == false) {
LOG(FATAL) << "Wait sender socket failed.";
}
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::string ip = args[0];
int port = args[1];
int recv_id = args[2];
std::string addr;
addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
*rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id);
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int max_try_times = args[0];
*rv = RPCContext::getInstance()->sender->ConnectReceiverFinalize(max_try_times);
});
*rv = RPCContext::getInstance()->sender->ConnectReceiverFinalize(
max_try_times);
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t rank = args[0];
RPCContext::getInstance()->rank = rank;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::getInstance()->rank;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t num_servers = args[0];
*rv = RPCContext::getInstance()->num_servers = num_servers;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::getInstance()->num_servers;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t num_clients = args[0];
*rv = RPCContext::getInstance()->num_clients = num_clients;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::getInstance()->num_clients;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t num_servers = args[0];
*rv = RPCContext::getInstance()->num_servers_per_machine = num_servers;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::getInstance()->num_servers_per_machine;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = (RPCContext::getInstance()->msg_seq)++;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::getInstance()->msg_seq;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int64_t msg_seq = args[0];
RPCContext::getInstance()->msg_seq = msg_seq;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t group_id = args[0];
auto&& cnt = RPCContext::getInstance()->barrier_count;
if (cnt.find(group_id) == cnt.end()) {
cnt.emplace(group_id, 0x0);
}
*rv = cnt[group_id];
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t count = args[0];
const int32_t group_id = args[1];
RPCContext::getInstance()->barrier_count[group_id] = count;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::getInstance()->machine_id;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t machine_id = args[0];
RPCContext::getInstance()->machine_id = machine_id;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::getInstance()->num_machines;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t num_machines = args[0];
RPCContext::getInstance()->num_machines = num_machines;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
RPCMessageRef msg = args[0];
const int32_t target_id = args[1];
*rv = SendRPCMessage(*(msg.sptr()), target_id);
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
int32_t timeout = args[0];
RPCMessageRef msg = args[1];
*rv = RecvRPCMessage(msg.sptr().get(), timeout);
});
});
//////////////////////////// RPCMessage ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<RPCMessage> rst(new RPCMessage);
*rv = rst;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<RPCMessage> rst(new RPCMessage);
rst->service_id = args[0];
rst->msg_seq = args[1];
......@@ -333,48 +338,48 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
rst->tensors = ListValueToVector<NDArray>(args[5]);
rst->group_id = args[6];
*rv = rst;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
*rv = msg->service_id;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
*rv = msg->msg_seq;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
*rv = msg->client_id;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
*rv = msg->server_id;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
DGLByteArray barr{msg->data.c_str(), msg->data.size()};
*rv = barr;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
List<Value> ret;
for (size_t i = 0; i < msg->tensors.size(); ++i) {
ret.push_back(Value(MakeValue(msg->tensors[i])));
}
*rv = ret;
});
});
#if defined(__linux__)
/*!
......@@ -388,7 +393,7 @@ void SigHandler(int s) {
}
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
// Ctrl+C handler
struct sigaction sigHandler;
sigHandler.sa_handler = SigHandler;
......@@ -396,24 +401,25 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
sigHandler.sa_flags = 0;
sigaction(SIGINT, &sigHandler, nullptr);
sigaction(SIGTERM, &sigHandler, nullptr);
});
});
#endif
//////////////////////////// ServerState ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
auto st = RPCContext::getInstance()->server_state;
if (st.get() == nullptr) {
RPCContext::getInstance()->server_state = std::make_shared<ServerState>();
RPCContext::getInstance()->server_state =
std::make_shared<ServerState>();
}
*rv = st;
});
});
//////////////////////////// KVStore ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray ID = args[0];
NDArray part_id = args[1];
int local_machine_id = args[2];
......@@ -428,10 +434,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
}
NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);
*rv = res_tensor;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
// Input
std::string name = args[0];
int local_machine_id = args[1];
......@@ -496,7 +502,8 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
msg.client_id = client_id;
int lower = i * group_count;
int upper = (i + 1) * group_count;
msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
msg.server_id =
dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
msg.data = pickle_data;
NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
msg.tensors.push_back(tensor);
......@@ -506,18 +513,18 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
}
}
local_data_shape[0] = ID_size;
NDArray res_tensor = NDArray::Empty(local_data_shape,
local_data->dtype,
DGLContext{kDGLCPU, 0});
NDArray res_tensor = NDArray::Empty(
local_data_shape, local_data->dtype, DGLContext{kDGLCPU, 0});
char* return_data = static_cast<char*>(res_tensor->data);
// Copy local data
parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) {
CHECK_GE(ID_size * row_size,
local_ids_orginal[i] * row_size + row_size);
CHECK_GE(
ID_size * row_size, local_ids_orginal[i] * row_size + row_size);
CHECK_GE(data_size, local_ids[i] * row_size + row_size);
CHECK_GE(local_ids[i], 0);
memcpy(return_data + local_ids_orginal[i] * row_size,
memcpy(
return_data + local_ids_orginal[i] * row_size,
local_data_char + local_ids[i] * row_size, row_size);
}
});
......@@ -532,43 +539,44 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
char* data_char = static_cast<char*>(msg.tensors[0]->data);
dgl_id_t id_size = remote_ids[part_id].size();
for (size_t n = 0; n < id_size; ++n) {
memcpy(return_data + remote_ids_original[part_id][n] * row_size,
memcpy(
return_data + remote_ids_original[part_id][n] * row_size,
data_char + n * row_size, row_size);
}
}
*rv = res_tensor;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::getInstance()->group_id;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t group_id = args[0];
RPCContext::getInstance()->group_id = group_id;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
*rv = msg->group_id;
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t client_id = args[0];
const int32_t group_id = args[1];
*rv = RPCContext::getInstance()->RegisterClient(client_id, group_id);
});
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t client_id = args[0];
const int32_t group_id = args[1];
*rv = RPCContext::getInstance()->GetClient(client_id, group_id);
});
});
} // namespace rpc
} // namespace dgl
......
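All of the C API entry points above follow the same DGL_REGISTER_GLOBAL + set_body pattern. A hedged sketch of registering and invoking one such packed function is below; the function name is hypothetical, and the DGLArgs/DGLRetValue conversions mirror the usage in the hunks above (Registry::Get returning a callable pointer is taken from the runtime code later in this diff).

#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dmlc/logging.h>

#include <cstdint>

using namespace dgl::runtime;

// Register a packed function under a dotted name, mirroring the pattern above.
DGL_REGISTER_GLOBAL("example._CAPI_Add")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t a = args[0];
      const int64_t b = args[1];
      *rv = a + b;
    });

int64_t CallAdd(int64_t a, int64_t b) {
  // Registry::Get returns nullptr if nothing was registered under this name.
  const PackedFunc* f = Registry::Get("example._CAPI_Add");
  CHECK(f != nullptr) << "example._CAPI_Add is not registered";
  return (*f)(a, b);  // the return value converts back through DGLRetValue
}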
......@@ -6,24 +6,25 @@
#ifndef DGL_RPC_RPC_H_
#define DGL_RPC_RPC_H_
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/thread_local.h>
#include <cstdint>
#include <memory>
#include <deque>
#include <vector>
#include <string>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "./network/common.h"
#include "./rpc_msg.h"
#include "./server_state.h"
#include "net_type.h"
#include "network/socket_communicator.h"
#include "tensorpipe/tp_communicator.h"
#include "./network/common.h"
#include "./server_state.h"
namespace dgl {
namespace rpc {
......@@ -138,7 +139,7 @@ struct RPCContext {
}
int32_t RegisterClient(int32_t client_id, int32_t group_id) {
auto &&m = clients_[group_id];
auto&& m = clients_[group_id];
if (m.find(client_id) != m.end()) {
return -1;
}
......@@ -150,7 +151,7 @@ struct RPCContext {
if (clients_.find(group_id) == clients_.end()) {
return -1;
}
const auto &m = clients_.at(group_id);
const auto& m = clients_.at(group_id);
if (m.find(client_id) == m.end()) {
return -1;
}
......
......@@ -6,8 +6,8 @@
#ifndef DGL_RPC_RPC_MSG_H_
#define DGL_RPC_RPC_MSG_H_
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <string>
......
......@@ -7,11 +7,12 @@
#ifndef DGL_RPC_SERVER_STATE_H_
#define DGL_RPC_SERVER_STATE_H_
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/base_heterograph.h>
#include <unordered_map>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <string>
#include <unordered_map>
namespace dgl {
namespace rpc {
......
......@@ -9,10 +9,11 @@
#define DGL_RPC_TENSORPIPE_QUEUE_H_
#include <dmlc/logging.h>
#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <chrono>
#include <utility>
namespace dgl {
......@@ -39,8 +40,9 @@ class Queue {
DLOG(WARNING) << "Will wait infinitely until message is popped...";
cv_.wait(lock, [this] { return items_.size() > 0; });
} else {
if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout),
[this] { return items_.size() > 0; })) {
if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout), [this] {
return items_.size() > 0;
})) {
DLOG(WARNING) << "Times out for popping message after " << timeout
<< " milliseconds.";
return false;
......
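The wrapped lambda above is the standard condition-variable timed-wait idiom. A self-contained sketch of the same pattern, independent of the DGL Queue class:

#include <chrono>
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>

std::mutex m;
std::condition_variable cv;
std::deque<int> items;

// Pop with a millisecond timeout; returns false if nothing arrived in time.
bool PopWithTimeout(int* out, int timeout_ms) {
  std::unique_lock<std::mutex> lock(m);
  if (!cv.wait_for(lock, std::chrono::milliseconds(timeout_ms),
                   [] { return !items.empty(); })) {
    return false;  // timed out and the predicate is still false
  }
  *out = items.front();
  items.pop_front();
  return true;
}

int main() {
  int value = 0;
  if (!PopWithTimeout(&value, 100)) std::cout << "timed out\n";
  return 0;
}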
......@@ -48,8 +48,8 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
StreamWithBuffer zc_write_strm(zerocopy_blob_ptr, true);
zc_write_strm.Write(msg);
int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
zerocopy_blob_ptr->append(reinterpret_cast<char *>(&nonempty_ndarray_count),
sizeof(int32_t));
zerocopy_blob_ptr->append(
reinterpret_cast<char *>(&nonempty_ndarray_count), sizeof(int32_t));
tp_msg.tensors.resize(nonempty_ndarray_count);
// Hold the NDArray to ensure it stays valid until the write operation completes
auto ndarray_holder = std::make_shared<std::vector<NDArray>>();
......@@ -68,8 +68,8 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
}
// Let's write blockingly in case of congestion in underlying transports.
auto done = std::make_shared<std::promise<void>>();
pipe->write(tp_msg,
[ndarray_holder, recv_id, done](const tensorpipe::Error &error) {
pipe->write(
tp_msg, [ndarray_holder, recv_id, done](const tensorpipe::Error &error) {
if (error) {
LOG(FATAL) << "Failed to send message to " << recv_id
<< ". Details: " << error.what();
......@@ -120,7 +120,8 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
if (error.isOfType<ListenerClosedError>()) {
// Expected.
} else {
LOG(WARNING) << "Unexpected error when accepting incoming pipe: " << error.what();
LOG(WARNING) << "Unexpected error when accepting incoming pipe: "
<< error.what();
}
return;
}
......@@ -133,7 +134,8 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
// read the handshake message: "dglconnect"
pipe->readDescriptor([pipe, this](const Error &error, Descriptor descriptor) {
if (error) {
LOG(ERROR) << "Unexpected error when reading from accepted pipe: " << error.what();
LOG(ERROR) << "Unexpected error when reading from accepted pipe: "
<< error.what();
return;
}
Allocation allocation;
......@@ -145,10 +147,10 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
});
}
void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
std::shared_ptr<RPCMessageQueue> queue) {
pipe->readDescriptor([pipe, queue = std::move(queue)](const Error &error,
Descriptor descriptor) {
void TPReceiver::ReceiveFromPipe(
std::shared_ptr<Pipe> pipe, std::shared_ptr<RPCMessageQueue> queue) {
pipe->readDescriptor([pipe, queue = std::move(queue)](
const Error &error, Descriptor descriptor) {
if (error) {
// Error may happen when the pipe is closed
return;
......@@ -165,22 +167,24 @@ void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
allocation.tensors[i].buffer = cpu_buffer;
}
}
pipe->read(allocation, [allocation, descriptor = std::move(descriptor),
queue = std::move(queue),
pipe](const Error &error) {
pipe->read(
allocation, [allocation, descriptor = std::move(descriptor),
queue = std::move(queue), pipe](const Error &error) {
if (error) {
// Because we always have a read event posted to the epoll,
// Therefore when pipe is closed, error will be raised.
// But this error is expected.
// Other error is not expected. But we cannot identify the error with
// each Other for now. Thus here we skip handling for all errors
// Other error is not expected. But we cannot identify the error
// with each Other for now. Thus here we skip handling for all
// errors
return;
}
char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
std::vector<void *> buffer_list(descriptor.tensors.size());
for (size_t i = 0; i < descriptor.tensors.size(); i++) {
buffer_list[i] = allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
buffer_list[i] =
allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
}
StreamWithBuffer zc_read_strm(
meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t),
......
......@@ -9,15 +9,16 @@
#include <dmlc/logging.h>
#include <tensorpipe/tensorpipe.h>
#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include <atomic>
#include "./queue.h"
#include "../net_type.h"
#include "./queue.h"
namespace dgl {
namespace rpc {
......@@ -48,9 +49,10 @@ class TPSender : public RPCSender {
/*!
* \brief Connect to a receiver.
*
* When there are multiple receivers to be connected, application will call `ConnectReceiver`
* for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
* successfully established or some of them fail.
* When there are multiple receivers to be connected, application will call
* `ConnectReceiver` for each and then call `ConnectReceiverFinalize` to make
* sure that either all the connections are successfully established or some
* of them fail.
*
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param recv_id receiver's ID
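A hedged sketch of the two-phase connect described above, using only the sender calls that appear in the C API hunks earlier in this diff (ConnectReceiver and ConnectReceiverFinalize); the addresses, IDs, retry count, and the assumption that both calls return a success flag are illustrative.

#include <dmlc/logging.h>

#include "rpc.h"  // assumed header providing RPCSender (see the includes above)

// `sender` is any RPCSender implementation, e.g. the TPSender declared here.
void ConnectAll(dgl::rpc::RPCSender* sender) {
  // Phase 1: issue one connect request per receiver.
  sender->ConnectReceiver("tcp://127.0.0.1:50091", /*recv_id=*/0);
  sender->ConnectReceiver("tcp://127.0.0.1:50092", /*recv_id=*/1);
  // Phase 2: block until every connection is established, or fail after a
  // bounded number of retries.
  if (!sender->ConnectReceiverFinalize(/*max_try_times=*/1024)) {
    LOG(FATAL) << "could not reach all receivers";
  }
}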
......@@ -75,7 +77,7 @@ class TPSender : public RPCSender {
/*!
* \brief Communicator type: 'tp'
*/
const std::string &NetType() const override {
const std::string& NetType() const override {
static const std::string net_type = "tensorpipe";
return net_type;
}
......@@ -129,13 +131,14 @@ class TPReceiver : public RPCReceiver {
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
bool Wait(const std::string &addr, int num_sender,
bool blocking = true) override;
bool Wait(
const std::string& addr, int num_sender, bool blocking = true) override;
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/
RPCStatus Recv(RPCMessage* msg, int timeout) override;
......@@ -150,7 +153,7 @@ class TPReceiver : public RPCReceiver {
/*!
* \brief Communicator type: 'tp' (tensorpipe)
*/
const std::string &NetType() const override {
const std::string& NetType() const override {
static const std::string net_type = "tensorpipe";
return net_type;
}
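Putting the receiver-side pieces together, a hedged sketch of the Wait/Recv flow using the signatures shown above (receiver construction, the listen address, and the enum namespace are assumptions):

#include <dmlc/logging.h>

#include "tp_communicator.h"  // assumed header for the TPReceiver declared here

void ServeOnce(dgl::rpc::TPReceiver* receiver) {
  // Block until the expected number of senders have connected.
  if (!receiver->Wait("tcp://0.0.0.0:50091", /*num_sender=*/2,
                      /*blocking=*/true)) {
    LOG(FATAL) << "waiting for senders failed";
  }
  dgl::rpc::RPCMessage msg;
  // A zero timeout waits indefinitely, per the Recv comment above.
  if (receiver->Recv(&msg, /*timeout=*/0) == dgl::rpc::kRPCSuccess) {
    LOG(INFO) << "got message with service_id " << msg.service_id;
  }
}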
......@@ -158,7 +161,8 @@ class TPReceiver : public RPCReceiver {
/*!
* \brief Issue a receive request on pipe, and push the result into queue
*/
static void ReceiveFromPipe(std::shared_ptr<tensorpipe::Pipe> pipe,
static void ReceiveFromPipe(
std::shared_ptr<tensorpipe::Pipe> pipe,
std::shared_ptr<RPCMessageQueue> queue);
private:
......@@ -186,8 +190,8 @@ class TPReceiver : public RPCReceiver {
/*!
* \brief pipe for each client connections
*/
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<tensorpipe::Pipe>>
std::unordered_map<
int /* Sender (virtual) ID */, std::shared_ptr<tensorpipe::Pipe>>
pipes_;
/*!
......
......@@ -3,16 +3,18 @@
* Implementation of C API (reference: tvm/src/api/c_api.cc)
* \file c_api.cc
*/
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/c_object_api.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_object_api.h>
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/packed_func_ext.h>
#include <vector>
#include <string>
#include <exception>
#include <string>
#include <vector>
#include "runtime_base.h"
/*! \brief entry to easily hold returning information */
......@@ -20,7 +22,7 @@ struct DGLAPIThreadLocalEntry {
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
std::vector<const char*> ret_vec_charp;
/*! \brief result holder for returning string */
std::string ret_str;
};
......@@ -44,7 +46,8 @@ struct APIAttrGetter : public AttrVisitor {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, uint64_t* value) final {
CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
CHECK_LE(
value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
<< "cannot return too big constant";
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
......@@ -71,30 +74,16 @@ struct APIAttrGetter : public AttrVisitor {
struct APIAttrDir : public AttrVisitor {
std::vector<std::string>* names;
void Visit(const char* key, double* value) final {
names->push_back(key);
}
void Visit(const char* key, int64_t* value) final {
names->push_back(key);
}
void Visit(const char* key, uint64_t* value) final {
names->push_back(key);
}
void Visit(const char* key, bool* value) final {
names->push_back(key);
}
void Visit(const char* key, int* value) final {
names->push_back(key);
}
void Visit(const char* key, double* value) final { names->push_back(key); }
void Visit(const char* key, int64_t* value) final { names->push_back(key); }
void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
void Visit(const char* key, bool* value) final { names->push_back(key); }
void Visit(const char* key, int* value) final { names->push_back(key); }
void Visit(const char* key, std::string* value) final {
names->push_back(key);
}
void Visit(const char* key, ObjectRef* value) final {
names->push_back(key);
}
void Visit(const char* key, NDArray* value) final {
names->push_back(key);
}
void Visit(const char* key, ObjectRef* value) final { names->push_back(key); }
void Visit(const char* key, NDArray* value) final { names->push_back(key); }
};
int DGLObjectFree(ObjectHandle handle) {
......@@ -103,25 +92,21 @@ int DGLObjectFree(ObjectHandle handle) {
API_END();
}
int DGLObjectTypeKey2Index(const char* type_key,
int* out_index) {
int DGLObjectTypeKey2Index(const char* type_key, int* out_index) {
API_BEGIN();
*out_index = static_cast<int>(Object::TypeKey2Index(type_key));
API_END();
}
int DGLObjectGetTypeIndex(ObjectHandle handle,
int* out_index) {
int DGLObjectGetTypeIndex(ObjectHandle handle, int* out_index) {
API_BEGIN();
*out_index = static_cast<int>(
(*static_cast<DGLAPIObject*>(handle))->type_index());
*out_index =
static_cast<int>((*static_cast<DGLAPIObject*>(handle))->type_index());
API_END();
}
int DGLObjectGetAttr(ObjectHandle handle,
const char* key,
DGLValue* ret_val,
int* ret_type_code,
int DGLObjectGetAttr(
ObjectHandle handle, const char* key, DGLValue* ret_val, int* ret_type_code,
int* ret_success) {
API_BEGIN();
DGLRetValue rv;
......@@ -136,9 +121,8 @@ int DGLObjectGetAttr(ObjectHandle handle,
} else {
(*tobject)->VisitAttrs(&getter);
*ret_success = getter.found_object_ref || rv.type_code() != kNull;
if (rv.type_code() == kStr ||
rv.type_code() == kDGLDataType) {
DGLAPIThreadLocalEntry *e = DGLAPIThreadLocalStore::Get();
if (rv.type_code() == kStr || rv.type_code() == kDGLDataType) {
DGLAPIThreadLocalEntry* e = DGLAPIThreadLocalStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
......@@ -149,10 +133,9 @@ int DGLObjectGetAttr(ObjectHandle handle,
API_END();
}
int DGLObjectListAttrNames(ObjectHandle handle,
int *out_size,
const char*** out_array) {
DGLAPIThreadLocalEntry *ret = DGLAPIThreadLocalStore::Get();
int DGLObjectListAttrNames(
ObjectHandle handle, int* out_size, const char*** out_array) {
DGLAPIThreadLocalEntry* ret = DGLAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str.clear();
DGLAPIObject* tobject = static_cast<DGLAPIObject*>(handle);
......
......@@ -3,18 +3,20 @@
* \file c_runtime_api.cc
* \brief Runtime API implementation
*/
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>
#include <array>
#include <dmlc/thread_local.h>
#include <algorithm>
#include <string>
#include <array>
#include <cstdlib>
#include <string>
#include "runtime_base.h"
namespace dgl {
......@@ -26,10 +28,14 @@ namespace runtime {
*/
inline std::string DeviceName(int type) {
switch (type) {
case kDGLCPU: return "cpu";
case kDGLCUDA: return "cuda";
case kDGLCPU:
return "cpu";
case kDGLCUDA:
return "cuda";
// add more device here once supported
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
default:
LOG(FATAL) << "unknown type =" << type;
return "Unknown";
}
}
......@@ -37,9 +43,7 @@ class DeviceAPIManager {
public:
static const int kMaxDeviceAPI = 32;
// Get API
static DeviceAPI* Get(const DGLContext& ctx) {
return Get(ctx.device_type);
}
static DeviceAPI* Get(const DGLContext& ctx) { return Get(ctx.device_type); }
static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
return Global()->GetAPI(dev_type, allow_missing);
}
......@@ -49,9 +53,7 @@ class DeviceAPIManager {
DeviceAPI* rpc_api_{nullptr};
std::mutex mutex_;
// constructor
DeviceAPIManager() {
std::fill(api_.begin(), api_.end(), nullptr);
}
DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
// Global static variable.
static DeviceAPIManager* Global() {
static DeviceAPIManager inst;
......@@ -78,7 +80,8 @@ class DeviceAPIManager {
auto* f = Registry::Get(factory);
if (f == nullptr) {
CHECK(allow_missing)
<< "Device API " << name << " is not enabled. Please install the cuda version of dgl.";
<< "Device API " << name
<< " is not enabled. Please install the cuda version of dgl.";
return nullptr;
}
void* ptr = (*f)();
......@@ -95,9 +98,8 @@ DeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {
return DeviceAPIManager::Get(static_cast<int>(dev_type), allow_missing);
}
void* DeviceAPI::AllocWorkspace(DGLContext ctx,
size_t size,
DGLDataType type_hint) {
void* DeviceAPI::AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) {
return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}
......@@ -114,9 +116,8 @@ void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
LOG(FATAL) << "Device does not support stream api.";
}
void DeviceAPI::SyncStreamFromTo(DGLContext ctx,
DGLStreamHandle event_src,
DGLStreamHandle event_dst) {
void DeviceAPI::SyncStreamFromTo(
DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api.";
}
......@@ -140,7 +141,7 @@ struct DGLRuntimeEntry {
typedef dmlc::ThreadLocalStore<DGLRuntimeEntry> DGLAPIRuntimeStore;
const char *DGLGetLastError() {
const char* DGLGetLastError() {
return DGLAPIRuntimeStore::Get()->last_error.c_str();
}
......@@ -152,30 +153,26 @@ void DGLAPISetLastError(const char* msg) {
#endif
}
int DGLModLoadFromFile(const char* file_name,
const char* format,
DGLModuleHandle* out) {
int DGLModLoadFromFile(
const char* file_name, const char* format, DGLModuleHandle* out) {
API_BEGIN();
Module m = Module::LoadFromFile(file_name, format);
*out = new Module(m);
API_END();
}
int DGLModImport(DGLModuleHandle mod,
DGLModuleHandle dep) {
int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep) {
API_BEGIN();
static_cast<Module*>(mod)->Import(
*static_cast<Module*>(dep));
static_cast<Module*>(mod)->Import(*static_cast<Module*>(dep));
API_END();
}
int DGLModGetFunction(DGLModuleHandle mod,
const char* func_name,
int query_imports,
DGLFunctionHandle *func) {
int DGLModGetFunction(
DGLModuleHandle mod, const char* func_name, int query_imports,
DGLFunctionHandle* func) {
API_BEGIN();
PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
func_name, query_imports != 0);
PackedFunc pf =
static_cast<Module*>(mod)->GetFunction(func_name, query_imports != 0);
if (pf != nullptr) {
*func = new PackedFunc(pf);
} else {
......@@ -190,19 +187,17 @@ int DGLModFree(DGLModuleHandle mod) {
API_END();
}
int DGLBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
DGLFunctionHandle *func) {
int DGLBackendGetFuncFromEnv(
void* mod_node, const char* func_name, DGLFunctionHandle* func) {
API_BEGIN();
*func = (DGLFunctionHandle)(
static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
*func =
(DGLFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(
func_name));
API_END();
}
void* DGLBackendAllocWorkspace(int device_type,
int device_id,
uint64_t size,
int dtype_code_hint,
void* DGLBackendAllocWorkspace(
int device_type, int device_id, uint64_t size, int dtype_code_hint,
int dtype_bits_hint) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type);
......@@ -213,14 +208,11 @@ void* DGLBackendAllocWorkspace(int device_type,
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1;
return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
static_cast<size_t>(size),
type_hint);
return DeviceAPIManager::Get(ctx)->AllocWorkspace(
ctx, static_cast<size_t>(size), type_hint);
}
int DGLBackendFreeWorkspace(int device_type,
int device_id,
void* ptr) {
int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id;
......@@ -228,10 +220,7 @@ int DGLBackendFreeWorkspace(int device_type,
return 0;
}
int DGLBackendRunOnce(void** handle,
int (*f)(void*),
void* cdata,
int nbytes) {
int DGLBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
if (*handle == nullptr) {
*handle = reinterpret_cast<void*>(1);
return (*f)(cdata);
......@@ -245,19 +234,15 @@ int DGLFuncFree(DGLFunctionHandle func) {
API_END();
}
int DGLFuncCall(DGLFunctionHandle func,
DGLValue* args,
int* arg_type_codes,
int num_args,
DGLValue* ret_val,
int* ret_type_code) {
int DGLFuncCall(
DGLFunctionHandle func, DGLValue* args, int* arg_type_codes, int num_args,
DGLValue* ret_val, int* ret_type_code) {
API_BEGIN();
DGLRetValue rv;
(*static_cast<const PackedFunc*>(func)).CallPacked(
DGLArgs(args, arg_type_codes, num_args), &rv);
(*static_cast<const PackedFunc*>(func))
.CallPacked(DGLArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
if (rv.type_code() == kStr ||
rv.type_code() == kDGLDataType ||
if (rv.type_code() == kStr || rv.type_code() == kDGLDataType ||
rv.type_code() == kBytes) {
DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
if (rv.type_code() != kDGLDataType) {
......@@ -280,10 +265,8 @@ int DGLFuncCall(DGLFunctionHandle func,
API_END();
}
int DGLCFuncSetReturn(DGLRetValueHandle ret,
DGLValue* value,
int* type_code,
int num_ret) {
int DGLCFuncSetReturn(
DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret) {
API_BEGIN();
CHECK_EQ(num_ret, 1);
DGLRetValue* rv = static_cast<DGLRetValue*>(ret);
......@@ -291,15 +274,15 @@ int DGLCFuncSetReturn(DGLRetValueHandle ret,
API_END();
}
int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
void* resource_handle,
DGLPackedCFuncFinalizer fin,
DGLFunctionHandle *out) {
int DGLFuncCreateFromCFunc(
DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
DGLFunctionHandle* out) {
API_BEGIN();
if (fin == nullptr) {
*out = new PackedFunc(
[func, resource_handle](DGLArgs args, DGLRetValue* rv) {
int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
*out =
new PackedFunc([func, resource_handle](DGLArgs args, DGLRetValue* rv) {
int ret = func(
(DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
if (ret != 0) {
std::string err = "DGLCall CFunc Error:\n";
......@@ -311,9 +294,9 @@ int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
// wrap it in a shared_ptr, with fin as deleter.
// so fin will be called when the lambda goes out of scope.
std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc(
[func, rpack](DGLArgs args, DGLRetValue* rv) {
int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
*out = new PackedFunc([func, rpack](DGLArgs args, DGLRetValue* rv) {
int ret = func(
(DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get());
if (ret != 0) {
std::string err = "DGLCall CFunc Error:\n";
......@@ -370,10 +353,8 @@ int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
API_END();
}
int DGLStreamStreamSynchronize(int device_type,
int device_id,
DGLStreamHandle src,
DGLStreamHandle dst) {
int DGLStreamStreamSynchronize(
int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst) {
API_BEGIN();
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type);
......@@ -392,13 +373,13 @@ int DGLCbArgToReturn(DGLValue* value, int code) {
API_END();
}
int DGLLoadTensorAdapter(const char *path) {
int DGLLoadTensorAdapter(const char* path) {
return TensorDispatcher::Global()->Load(path) ? 0 : -1;
}
// set device api
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
.set_body([](DGLArgs args, DGLRetValue *ret) {
.set_body([](DGLArgs args, DGLRetValue* ret) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
......@@ -407,7 +388,7 @@ DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
// set device api
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](DGLArgs args, DGLRetValue *ret) {
.set_body([](DGLArgs args, DGLRetValue* ret) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
......@@ -424,4 +405,3 @@ DGL_REGISTER_GLOBAL("_GetDeviceAttr")
DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
}
});
......@@ -4,30 +4,26 @@
* \brief DGL runtime config
*/
#include <dgl/runtime/registry.h>
#include <dgl/runtime/config.h>
#include <dgl/runtime/registry.h>
using namespace dgl::runtime;
namespace dgl {
namespace runtime {
void Config::EnableLibxsmm(bool b) {
libxsmm_ = b;
}
void Config::EnableLibxsmm(bool b) { libxsmm_ = b; }
bool Config::IsLibxsmmAvailable() const {
return libxsmm_;
}
bool Config::IsLibxsmmAvailable() const { return libxsmm_; }
DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigSetLibxsmm")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
bool use_libxsmm = args[0];
dgl::runtime::Config::Global()->EnableLibxsmm(use_libxsmm);
});
DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigGetLibxsmm")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = dgl::runtime::Config::Global()->IsLibxsmmAvailable();
});
......
......@@ -2,13 +2,15 @@
* Copyright (c) 2016-2022 by Contributors
* \file cpu_device_api.cc
*/
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <cstdlib>
#include <cstring>
#include "workspace_pool.h"
namespace dgl {
......@@ -21,13 +23,11 @@ class CPUDeviceAPI final : public DeviceAPI {
*rv = 1;
}
}
void* AllocDataSpace(DGLContext ctx,
size_t nbytes,
size_t alignment,
void* AllocDataSpace(
DGLContext ctx, size_t nbytes, size_t alignment,
DGLDataType type_hint) final {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CPUAllocWorkspace(nbytes);
if (td->IsAvailable()) return td->CPUAllocWorkspace(nbytes);
void* ptr;
#if _MSC_VER || defined(__MINGW32__)
......@@ -45,8 +45,7 @@ class CPUDeviceAPI final : public DeviceAPI {
void FreeDataSpace(DGLContext ctx, void* ptr) final {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CPUFreeWorkspace(ptr);
if (td->IsAvailable()) return td->CPUFreeWorkspace(ptr);
#if _MSC_VER || defined(__MINGW32__)
_aligned_free(ptr);
......@@ -55,25 +54,21 @@ class CPUDeviceAPI final : public DeviceAPI {
#endif
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
DGLContext ctx_from,
DGLContext ctx_to,
void CopyDataFromTo(
const void* from, size_t from_offset, void* to, size_t to_offset,
size_t size, DGLContext ctx_from, DGLContext ctx_to,
DGLDataType type_hint) final {
memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
memcpy(
static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset, size);
}
DGLStreamHandle CreateStream(DGLContext) final { return nullptr; }
void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
}
void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {}
void* AllocWorkspace(DGLContext ctx, size_t size, DGLDataType type_hint) final;
void* AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) final;
void FreeWorkspace(DGLContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() {
......@@ -84,30 +79,27 @@ class CPUDeviceAPI final : public DeviceAPI {
};
struct CPUWorkspacePool : public WorkspacePool {
CPUWorkspacePool() :
WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
CPUWorkspacePool() : WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
};
void* CPUDeviceAPI::AllocWorkspace(DGLContext ctx,
size_t size,
DGLDataType type_hint) {
void* CPUDeviceAPI::AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CPUAllocWorkspace(size);
if (td->IsAvailable()) return td->CPUAllocWorkspace(size);
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(ctx, size);
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(
ctx, size);
}
void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CPUFreeWorkspace(data);
if (td->IsAvailable()) return td->CPUFreeWorkspace(data);
dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
}
DGL_REGISTER_GLOBAL("device_api.cpu")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CPUDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
......
......@@ -7,11 +7,13 @@
#define DGL_RUNTIME_CUDA_CUDA_COMMON_H_
#include <cublas_v2.h>
#include <cusparse.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <cusparse.h>
#include <dgl/runtime/packed_func.h>
#include <string>
#include "../workspace_pool.h"
namespace dgl {
......@@ -31,10 +33,9 @@ inline bool is_zero<dim3>(dim3 size) {
{ \
CUresult result = x; \
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \
const char *msg; \
const char* msg; \
cuGetErrorName(result, &msg); \
LOG(FATAL) \
<< "CUDAError: " #x " failed with error: " << msg; \
LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg; \
} \
}
......@@ -47,22 +48,18 @@ inline bool is_zero<dim3>(dim3 size) {
#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...) \
{ \
if (!dgl::runtime::is_zero((nblks)) && \
!dgl::runtime::is_zero((nthrs))) { \
(kernel) <<< (nblks), (nthrs), (shmem), (stream) >>> \
(__VA_ARGS__); \
if (!dgl::runtime::is_zero((nblks)) && !dgl::runtime::is_zero((nthrs))) { \
(kernel)<<<(nblks), (nthrs), (shmem), (stream)>>>(__VA_ARGS__); \
cudaError_t e = cudaGetLastError(); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA kernel launch error: " \
<< cudaGetErrorString(e); \
<< "CUDA kernel launch error: " << cudaGetErrorString(e); \
} \
}
#define CUSPARSE_CALL(func) \
{ \
cusparseStatus_t e = (func); \
CHECK(e == CUSPARSE_STATUS_SUCCESS) \
<< "CUSPARSE ERROR: " << e; \
CHECK(e == CUSPARSE_STATUS_SUCCESS) << "CUSPARSE ERROR: " << e; \
}
#define CUBLAS_CALL(func) \
......@@ -72,12 +69,12 @@ inline bool is_zero<dim3>(dim3 size) {
}
#define CURAND_CALL(func) \
{ \
{ \
curandStatus_t e = (func); \
CHECK(e == CURAND_STATUS_SUCCESS) \
<< "CURAND Error: " << dgl::runtime::curandGetErrorString(e) \
<< " at " << __FILE__ << ":" << __LINE__; \
}
<< "CURAND Error: " << dgl::runtime::curandGetErrorString(e) << " at " \
<< __FILE__ << ":" << __LINE__; \
}
inline const char* curandGetErrorString(curandStatus_t error) {
switch (error) {
......
......@@ -3,11 +3,12 @@
* \file cuda_device_api.cc
* \brief GPU specific API
*/
#include <cuda_runtime.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <cuda_runtime.h>
#include "cuda_common.h"
namespace dgl {
......@@ -28,9 +29,7 @@ class CUDADeviceAPI final : public DeviceAPI {
is_available_ = count > 0;
}
bool IsAvailable() final {
return is_available_;
}
bool IsAvailable() final { return is_available_; }
void SetDevice(DGLContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
......@@ -39,10 +38,10 @@ class CUDADeviceAPI final : public DeviceAPI {
int value = 0;
switch (kind) {
case kExist:
value = (
cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
== cudaSuccess);
value =
(cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) ==
cudaSuccess);
break;
case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute(
......@@ -50,8 +49,8 @@ class CUDADeviceAPI final : public DeviceAPI {
break;
}
case kWarpSize: {
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrWarpSize, ctx.device_id));
CUDA_CALL(
cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, ctx.device_id));
break;
}
case kMaxSharedMemoryPerBlock: {
......@@ -96,16 +95,15 @@ class CUDADeviceAPI final : public DeviceAPI {
&dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));
std::stringstream ss; // use json string to return multiple int values;
ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
return;
}
}
*rv = value;
}
void* AllocDataSpace(DGLContext ctx,
size_t nbytes,
size_t alignment,
void* AllocDataSpace(
DGLContext ctx, size_t nbytes, size_t alignment,
DGLDataType type_hint) final {
SetDevice(ctx);
// Redirect to PyTorch's allocator when available.
......@@ -113,9 +111,8 @@ class CUDADeviceAPI final : public DeviceAPI {
if (td->IsAvailable())
return td->CUDAAllocWorkspace(nbytes, getCurrentCUDAStream());
CHECK_EQ(256 % alignment, 0U)
<< "CUDA space is aligned at 256 bytes";
void *ret;
CHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes";
void* ret;
CUDA_CALL(cudaMalloc(&ret, nbytes));
return ret;
}
......@@ -123,21 +120,15 @@ class CUDADeviceAPI final : public DeviceAPI {
void FreeDataSpace(DGLContext ctx, void* ptr) final {
SetDevice(ctx);
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAFreeWorkspace(ptr);
if (td->IsAvailable()) return td->CUDAFreeWorkspace(ptr);
CUDA_CALL(cudaFree(ptr));
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint,
DGLStreamHandle stream) {
void CopyDataFromTo(
const void* from, size_t from_offset, void* to, size_t to_offset,
size_t size, DGLContext ctx_from, DGLContext ctx_to,
DGLDataType type_hint, DGLStreamHandle stream) {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset;
......@@ -146,14 +137,15 @@ class CUDADeviceAPI final : public DeviceAPI {
if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
} else {
CUDA_CALL(cudaMemcpyPeerAsync(to, ctx_to.device_id,
from, ctx_from.device_id,
size, cu_stream));
CUDA_CALL(cudaMemcpyPeerAsync(
to, ctx_to.device_id, from, ctx_from.device_id, size, cu_stream));
}
} else if (ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCPU) {
} else if (
ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
} else if (ctx_from.device_type == kDGLCPU && ctx_to.device_type == kDGLCUDA) {
} else if (
ctx_from.device_type == kDGLCPU && ctx_to.device_type == kDGLCUDA) {
CUDA_CALL(cudaSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
} else {
......@@ -161,16 +153,14 @@ class CUDADeviceAPI final : public DeviceAPI {
}
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
DGLContext ctx_from,
DGLContext ctx_to,
void CopyDataFromTo(
const void* from, size_t from_offset, void* to, size_t to_offset,
size_t size, DGLContext ctx_from, DGLContext ctx_to,
DGLDataType type_hint) final {
auto stream = GetStream();
CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint, stream);
CopyDataFromTo(
from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint,
stream);
}
DGLStreamHandle CreateStream(DGLContext ctx) {
......@@ -187,7 +177,8 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDA_CALL(cudaStreamDestroy(cu_stream));
}
void SyncStreamFromTo(DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
void SyncStreamFromTo(
DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
CUDA_CALL(cudaSetDevice(ctx.device_id));
cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
......@@ -222,21 +213,18 @@ class CUDADeviceAPI final : public DeviceAPI {
*/
void PinData(void* ptr, size_t nbytes) {
// prevent users from pinning empty tensors or graphs
if (ptr == nullptr || nbytes == 0)
return;
if (ptr == nullptr || nbytes == 0) return;
CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
}
void UnpinData(void* ptr) {
if (ptr == nullptr)
return;
if (ptr == nullptr) return;
CUDA_CALL(cudaHostUnregister(ptr));
}
bool IsPinned(const void* ptr) override {
// can't be a pinned tensor if CUDA context is unavailable.
if (!is_available_)
return false;
if (!is_available_) return false;
cudaPointerAttributes attr;
cudaError_t status = cudaPointerGetAttributes(&attr, ptr);
......@@ -254,22 +242,25 @@ class CUDADeviceAPI final : public DeviceAPI {
case cudaErrorNoDevice:
case cudaErrorInsufficientDriver:
case cudaErrorInvalidDevice:
// We don't want to fail in these particular cases since this function can be called
// when users only want to run on CPU even if CUDA API is enabled, or in a forked
// subprocess where CUDA context cannot be initialized. So we just mark the CUDA
// context to unavailable and return.
// We don't want to fail in these particular cases since this function
// can be called when users only want to run on CPU even if CUDA API is
// enabled, or in a forked subprocess where CUDA context cannot be
// initialized. So we just mark the CUDA context to unavailable and
// return.
is_available_ = false;
cudaGetLastError(); // clear error
break;
default:
LOG(FATAL) << "error while determining memory status: " << cudaGetErrorString(status);
LOG(FATAL) << "error while determining memory status: "
<< cudaGetErrorString(status);
break;
}
return result;
}
void* AllocWorkspace(DGLContext ctx, size_t size, DGLDataType type_hint) final {
void* AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) final {
SetDevice(ctx);
// Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global();
......@@ -282,8 +273,7 @@ class CUDADeviceAPI final : public DeviceAPI {
void FreeWorkspace(DGLContext ctx, void* data) final {
SetDevice(ctx);
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->CUDAFreeWorkspace(data);
if (td->IsAvailable()) return td->CUDAFreeWorkspace(data);
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
......@@ -295,14 +285,13 @@ class CUDADeviceAPI final : public DeviceAPI {
}
private:
static void GPUCopy(const void* from,
void* to,
size_t size,
cudaMemcpyKind kind,
static void GPUCopy(
const void* from, void* to, size_t size, cudaMemcpyKind kind,
cudaStream_t stream) {
CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
if (stream == 0 && kind == cudaMemcpyDeviceToHost) {
// only wait for the copy, when it's on the default stream, and it's to host memory
// only wait for the copy, when it's on the default stream, and it's to
// host memory
CUDA_CALL(cudaStreamSynchronize(stream));
}
}
......@@ -312,9 +301,7 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
CUDAThreadEntry::CUDAThreadEntry()
: pool(kDGLCUDA, CUDADeviceAPI::Global()) {
}
CUDAThreadEntry::CUDAThreadEntry() : pool(kDGLCUDA, CUDADeviceAPI::Global()) {}
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
return CUDAThreadStore::Get();
......@@ -329,7 +316,7 @@ cudaStream_t getCurrentCUDAStream() {
}
DGL_REGISTER_GLOBAL("device_api.cuda")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
......
......@@ -6,10 +6,10 @@
#include <cassert>
#include "cuda_common.h"
#include "cuda_hashtable.cuh"
#include "../../array/cuda/atomic.cuh"
#include "../../array/cuda/dgl_cub.cuh"
#include "cuda_common.h"
#include "cuda_hashtable.cuh"
using namespace dgl::aten::cuda;
......@@ -23,12 +23,12 @@ constexpr static const int BLOCK_SIZE = 256;
constexpr static const size_t TILE_SIZE = 1024;
/**
* @brief This is the mutable version of the DeviceOrderedHashTable, for use in
* inserting elements into the hashtable.
*
* @tparam IdType The type of ID to store in the hashtable.
*/
template<typename IdType>
* @brief This is the mutable version of the DeviceOrderedHashTable, for use in
* inserting elements into the hashtable.
*
* @tparam IdType The type of ID to store in the hashtable.
*/
template <typename IdType>
class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
public:
typedef typename DeviceOrderedHashTable<IdType>::Mapping* Iterator;
......@@ -40,9 +40,8 @@ class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
* @param hostTable The original hash table on the host.
*/
explicit MutableDeviceOrderedHashTable(
OrderedHashTable<IdType>* const hostTable) :
DeviceOrderedHashTable<IdType>(hostTable->DeviceHandle()) {
}
OrderedHashTable<IdType>* const hostTable)
: DeviceOrderedHashTable<IdType>(hostTable->DeviceHandle()) {}
/**
* @brief Find the mutable mapping of a given key within the hash table.
......@@ -54,8 +53,7 @@ class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
*
* @return The mapping.
*/
inline __device__ Iterator Search(
const IdType id) {
inline __device__ Iterator Search(const IdType id) {
const IdType pos = SearchForPosition(id);
return GetMutable(pos);
......@@ -71,15 +69,15 @@ class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
* \return True, if the insertion was successful.
*/
inline __device__ bool AttemptInsertAt(
const size_t pos,
const IdType id,
const size_t index) {
const size_t pos, const IdType id, const size_t index) {
const IdType key = AtomicCAS(&GetMutable(pos)->key, kEmptyKey, id);
if (key == kEmptyKey || key == id) {
      // we either inserted the key or found a matching key, so place the
// minimum index in position. Match the type of atomicMin, so ignore
// linting
atomicMin(reinterpret_cast<unsigned long long*>(&GetMutable(pos)->index), // NOLINT
atomicMin(
reinterpret_cast<unsigned long long*>( // NOLINT
&GetMutable(pos)->index),
static_cast<unsigned long long>(index)); // NOLINT
return true;
} else {
......@@ -96,16 +94,14 @@ class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
*
* @return An iterator to inserted mapping.
*/
inline __device__ Iterator Insert(
const IdType id,
const size_t index) {
inline __device__ Iterator Insert(const IdType id, const size_t index) {
size_t pos = Hash(id);
// linearly scan for an empty slot or matching entry
IdType delta = 1;
while (!AttemptInsertAt(pos, id, index)) {
pos = Hash(pos+delta);
delta +=1;
pos = Hash(pos + delta);
delta += 1;
}
return GetMutable(pos);
......@@ -124,44 +120,40 @@ class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
// The parent class Device is read-only, but we ensure this can only be
// constructed from a mutable version of OrderedHashTable, making this
// a safe cast to perform.
return const_cast<Iterator>(this->table_+pos);
return const_cast<Iterator>(this->table_ + pos);
}
};
/**
* @brief Calculate the number of buckets in the hashtable. To guarantee we can
* fill the hashtable in the worst case, we must use a number of buckets which
* is a power of two.
* https://en.wikipedia.org/wiki/Quadratic_probing#Limitations
*
* @param num The number of items to insert (should be an upper bound on the
* number of unique keys).
* @param scale The power of two larger the number of buckets should be than the
* unique keys.
*
* @return The number of buckets the table should contain.
*/
size_t TableSize(
const size_t num,
const int scale) {
* @brief Calculate the number of buckets in the hashtable. To guarantee we can
* fill the hashtable in the worst case, we must use a number of buckets which
* is a power of two.
* https://en.wikipedia.org/wiki/Quadratic_probing#Limitations
*
* @param num The number of items to insert (should be an upper bound on the
* number of unique keys).
 * @param scale How many powers of two larger the number of buckets should be
 * than the number of unique keys.
*
* @return The number of buckets the table should contain.
*/
size_t TableSize(const size_t num, const int scale) {
const size_t next_pow2 = 1 << static_cast<size_t>(1 + std::log2(num >> 1));
return next_pow2 << scale;
}
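To make the bucket arithmetic concrete, here is a host-only reproduction of the formula with one worked example (illustrative; it mirrors TableSize above and assumes num >= 2):

#include <cmath>
#include <cstdio>

// next_pow2 is the smallest power of two strictly greater than num / 2, so
// with scale >= 1 the table always ends up with more buckets than items.
size_t table_size(size_t num, int scale) {
  const size_t next_pow2 = 1 << static_cast<size_t>(1 + std::log2(num >> 1));
  return next_pow2 << scale;
}

int main() {
  // num = 1000: num >> 1 = 500, floor(1 + log2(500)) = 9, next_pow2 = 512,
  // and with the default scale of 3 the table gets 512 << 3 = 4096 buckets.
  std::printf("%zu\n", table_size(1000, 3));  // prints 4096
  return 0;
}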
/**
* @brief This structure is used with cub's block-level prefixscan in order to
* keep a running sum as items are iteratively processed.
*
* @tparam IdType The type to perform the prefixsum on.
*/
template<typename IdType>
* @brief This structure is used with cub's block-level prefixscan in order to
* keep a running sum as items are iteratively processed.
*
* @tparam IdType The type to perform the prefixsum on.
*/
template <typename IdType>
struct BlockPrefixCallbackOp {
IdType running_total_;
__device__ BlockPrefixCallbackOp(
const IdType running_total) :
running_total_(running_total) {
}
__device__ BlockPrefixCallbackOp(const IdType running_total)
: running_total_(running_total) {}
__device__ IdType operator()(const IdType block_aggregate) {
const IdType old_prefix = running_total_;
......@@ -173,28 +165,28 @@ struct BlockPrefixCallbackOp {
} // namespace
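BlockPrefixCallbackOp is the standard hook for chaining several cub::BlockScan passes inside one thread block while carrying the running total forward, which is the pattern compact_hashmap below relies on. A minimal self-contained sketch of that pattern (the kernel, constants, and functor names here are hypothetical, not DGL's):

#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>

constexpr int kBlockSize = 128;    // hypothetical; DGL uses BLOCK_SIZE = 256
constexpr int kValsPerThread = 4;  // mirrors TILE_SIZE / BLOCK_SIZE

// Same shape as BlockPrefixCallbackOp above: remember the sum of all
// previous passes and hand it to BlockScan as the prefix of the next one.
struct RunningPrefix {
  int total;
  __device__ explicit RunningPrefix(int t) : total(t) {}
  __device__ int operator()(int block_aggregate) {
    const int old = total;
    total += block_aggregate;
    return old;
  }
};

__global__ void tiled_exclusive_sum(const int* in, int* out, int n) {
  using BlockScan = cub::BlockScan<int, kBlockSize>;
  __shared__ typename BlockScan::TempStorage temp;
  RunningPrefix prefix(0);
  for (int pass = 0; pass < kValsPerThread; ++pass) {
    const int idx = pass * kBlockSize + threadIdx.x;
    const int v = idx < n ? in[idx] : 0;
    int ex;
    BlockScan(temp).ExclusiveSum(v, ex, prefix);
    __syncthreads();  // TempStorage is reused by the next pass
    if (idx < n) out[idx] = ex;
  }
}

int main() {
  const int n = kBlockSize * kValsPerThread;
  int h_in[n], h_out[n];
  for (int i = 0; i < n; ++i) h_in[i] = 1;
  int *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(int));
  cudaMalloc(&d_out, n * sizeof(int));
  cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);
  tiled_exclusive_sum<<<1, kBlockSize>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, n * sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("out[%d] = %d (expect %d)\n", n - 1, h_out[n - 1], n - 1);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}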
/**
* \brief This generates a hash map where the keys are the global item numbers,
* and the values are indexes, and inputs may have duplciates.
*
* \tparam IdType The type of of id.
* \tparam BLOCK_SIZE The size of the thread block.
* \tparam TILE_SIZE The number of entries each thread block will process.
* \param items The items to insert.
* \param num_items The number of items to insert.
* \param table The hash table.
*/
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
* \brief This generates a hash map where the keys are the global item numbers,
 * and the values are indexes, and inputs may have duplicates.
*
 * \tparam IdType The type of id.
* \tparam BLOCK_SIZE The size of the thread block.
* \tparam TILE_SIZE The number of entries each thread block will process.
* \param items The items to insert.
* \param num_items The number of items to insert.
* \param table The hash table.
*/
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void generate_hashmap_duplicates(
const IdType * const items,
const int64_t num_items,
const IdType* const items, const int64_t num_items,
MutableDeviceOrderedHashTable<IdType> table) {
assert(BLOCK_SIZE == blockDim.x);
const size_t block_start = TILE_SIZE*blockIdx.x;
const size_t block_end = TILE_SIZE*(blockIdx.x+1);
const size_t block_start = TILE_SIZE * blockIdx.x;
const size_t block_end = TILE_SIZE * (blockIdx.x + 1);
#pragma unroll
for (size_t index = threadIdx.x + block_start; index < block_end; index += BLOCK_SIZE) {
#pragma unroll
for (size_t index = threadIdx.x + block_start; index < block_end;
index += BLOCK_SIZE) {
if (index < num_items) {
table.Insert(items[index], index);
}
......@@ -202,30 +194,30 @@ __global__ void generate_hashmap_duplicates(
}
/**
* \brief This generates a hash map where the keys are the global item numbers,
* and the values are indexes, and all inputs are unique.
*
* \tparam IdType The type of of id.
* \tparam BLOCK_SIZE The size of the thread block.
* \tparam TILE_SIZE The number of entries each thread block will process.
* \param items The unique items to insert.
* \param num_items The number of items to insert.
* \param table The hash table.
*/
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
* \brief This generates a hash map where the keys are the global item numbers,
* and the values are indexes, and all inputs are unique.
*
 * \tparam IdType The type of id.
* \tparam BLOCK_SIZE The size of the thread block.
* \tparam TILE_SIZE The number of entries each thread block will process.
* \param items The unique items to insert.
* \param num_items The number of items to insert.
* \param table The hash table.
*/
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void generate_hashmap_unique(
const IdType * const items,
const int64_t num_items,
const IdType* const items, const int64_t num_items,
MutableDeviceOrderedHashTable<IdType> table) {
assert(BLOCK_SIZE == blockDim.x);
using Iterator = typename MutableDeviceOrderedHashTable<IdType>::Iterator;
const size_t block_start = TILE_SIZE*blockIdx.x;
const size_t block_end = TILE_SIZE*(blockIdx.x+1);
const size_t block_start = TILE_SIZE * blockIdx.x;
const size_t block_end = TILE_SIZE * (blockIdx.x + 1);
#pragma unroll
for (size_t index = threadIdx.x + block_start; index < block_end; index += BLOCK_SIZE) {
#pragma unroll
for (size_t index = threadIdx.x + block_start; index < block_end;
index += BLOCK_SIZE) {
if (index < num_items) {
const Iterator pos = table.Insert(items[index], index);
......@@ -237,35 +229,34 @@ __global__ void generate_hashmap_unique(
}
/**
* \brief This counts the number of nodes inserted per thread block.
*
* \tparam IdType The type of of id.
* \tparam BLOCK_SIZE The size of the thread block.
* \tparam TILE_SIZE The number of entries each thread block will process.
* \param input The nodes to insert.
* \param num_input The number of nodes to insert.
* \param table The hash table.
* \param num_unique The number of nodes inserted into the hash table per thread
* block.
*/
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
* \brief This counts the number of nodes inserted per thread block.
*
 * \tparam IdType The type of id.
* \tparam BLOCK_SIZE The size of the thread block.
* \tparam TILE_SIZE The number of entries each thread block will process.
* \param input The nodes to insert.
* \param num_input The number of nodes to insert.
* \param table The hash table.
* \param num_unique The number of nodes inserted into the hash table per thread
* block.
*/
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void count_hashmap(
const IdType * items,
const size_t num_items,
DeviceOrderedHashTable<IdType> table,
IdType * const num_unique) {
const IdType* items, const size_t num_items,
DeviceOrderedHashTable<IdType> table, IdType* const num_unique) {
assert(BLOCK_SIZE == blockDim.x);
using BlockReduce = typename cub::BlockReduce<IdType, BLOCK_SIZE>;
using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;
const size_t block_start = TILE_SIZE*blockIdx.x;
const size_t block_end = TILE_SIZE*(blockIdx.x+1);
const size_t block_start = TILE_SIZE * blockIdx.x;
const size_t block_end = TILE_SIZE * (blockIdx.x + 1);
IdType count = 0;
#pragma unroll
for (size_t index = threadIdx.x + block_start; index < block_end; index += BLOCK_SIZE) {
#pragma unroll
for (size_t index = threadIdx.x + block_start; index < block_end;
index += BLOCK_SIZE) {
if (index < num_items) {
const Mapping& mapping = *table.Search(items[index]);
if (mapping.index == index) {
......@@ -286,29 +277,26 @@ __global__ void count_hashmap(
}
}
/**
* \brief Update the local numbering of elements in the hashmap.
*
* \tparam IdType The type of id.
* \tparam BLOCK_SIZE The size of the thread blocks.
* \tparam TILE_SIZE The number of elements each thread block works on.
* \param items The set of non-unique items to update from.
* \param num_items The number of non-unique items.
* \param table The hash table.
* \param num_items_prefix The number of unique items preceding each thread
* block.
* \param unique_items The set of unique items (output).
* \param num_unique_items The number of unique items (output).
*/
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
* \brief Update the local numbering of elements in the hashmap.
*
* \tparam IdType The type of id.
* \tparam BLOCK_SIZE The size of the thread blocks.
* \tparam TILE_SIZE The number of elements each thread block works on.
* \param items The set of non-unique items to update from.
* \param num_items The number of non-unique items.
* \param table The hash table.
* \param num_items_prefix The number of unique items preceding each thread
* block.
* \param unique_items The set of unique items (output).
* \param num_unique_items The number of unique items (output).
*/
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void compact_hashmap(
const IdType * const items,
const size_t num_items,
const IdType* const items, const size_t num_items,
MutableDeviceOrderedHashTable<IdType> table,
const IdType * const num_items_prefix,
IdType * const unique_items,
int64_t * const num_unique_items) {
const IdType* const num_items_prefix, IdType* const unique_items,
int64_t* const num_unique_items) {
assert(BLOCK_SIZE == blockDim.x);
using FlagType = uint16_t;
......@@ -325,10 +313,10 @@ __global__ void compact_hashmap(
// count successful placements
for (int32_t i = 0; i < VALS_PER_THREAD; ++i) {
const IdType index = threadIdx.x + i*BLOCK_SIZE + blockIdx.x*TILE_SIZE;
const IdType index = threadIdx.x + i * BLOCK_SIZE + blockIdx.x * TILE_SIZE;
FlagType flag;
Mapping * kv;
Mapping* kv;
if (index < num_items) {
kv = table.Search(items[index]);
flag = kv->index == index;
......@@ -344,7 +332,7 @@ __global__ void compact_hashmap(
__syncthreads();
if (kv) {
const IdType pos = offset+flag;
const IdType pos = offset + flag;
kv->local = pos;
unique_items[pos] = items[index];
}
......@@ -357,128 +345,94 @@ __global__ void compact_hashmap(
// DeviceOrderedHashTable implementation
template<typename IdType>
template <typename IdType>
DeviceOrderedHashTable<IdType>::DeviceOrderedHashTable(
const Mapping* const table,
const size_t size) :
table_(table),
size_(size) {
}
const Mapping* const table, const size_t size)
: table_(table), size_(size) {}
template<typename IdType>
template <typename IdType>
DeviceOrderedHashTable<IdType> OrderedHashTable<IdType>::DeviceHandle() const {
return DeviceOrderedHashTable<IdType>(table_, size_);
}
// OrderedHashTable implementation
template<typename IdType>
template <typename IdType>
OrderedHashTable<IdType>::OrderedHashTable(
const size_t size,
DGLContext ctx,
cudaStream_t stream,
const int scale) :
table_(nullptr),
size_(TableSize(size, scale)),
ctx_(ctx) {
const size_t size, DGLContext ctx, cudaStream_t stream, const int scale)
: table_(nullptr), size_(TableSize(size, scale)), ctx_(ctx) {
  // make sure we will have at least as many buckets as items.
CHECK_GT(scale, 0);
auto device = runtime::DeviceAPI::Get(ctx_);
table_ = static_cast<Mapping*>(
device->AllocWorkspace(ctx_, sizeof(Mapping)*size_));
device->AllocWorkspace(ctx_, sizeof(Mapping) * size_));
CUDA_CALL(cudaMemsetAsync(
table_,
DeviceOrderedHashTable<IdType>::kEmptyKey,
sizeof(Mapping)*size_,
stream));
table_, DeviceOrderedHashTable<IdType>::kEmptyKey,
sizeof(Mapping) * size_, stream));
}
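The cudaMemsetAsync above fills the table byte-wise with kEmptyKey. That only works because kEmptyKey is IdType(-1), whose two's-complement representation is all 0xFF bytes (the "must be uniform bytes" remark in the hashtable header below). A tiny host-side check of that assumption (illustrative only):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  int64_t slot;
  std::memset(&slot, 0xFF, sizeof(slot));  // byte-wise fill, like cudaMemsetAsync
  std::printf("%lld\n", static_cast<long long>(slot));  // prints -1
  return 0;
}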
template<typename IdType>
template <typename IdType>
OrderedHashTable<IdType>::~OrderedHashTable() {
auto device = runtime::DeviceAPI::Get(ctx_);
device->FreeWorkspace(ctx_, table_);
}
template<typename IdType>
template <typename IdType>
void OrderedHashTable<IdType>::FillWithDuplicates(
const IdType * const input,
const size_t num_input,
IdType * const unique,
int64_t * const num_unique,
cudaStream_t stream) {
const IdType* const input, const size_t num_input, IdType* const unique,
int64_t* const num_unique, cudaStream_t stream) {
auto device = runtime::DeviceAPI::Get(ctx_);
const int64_t num_tiles = (num_input+TILE_SIZE-1)/TILE_SIZE;
const int64_t num_tiles = (num_input + TILE_SIZE - 1) / TILE_SIZE;
const dim3 grid(num_tiles);
const dim3 block(BLOCK_SIZE);
auto device_table = MutableDeviceOrderedHashTable<IdType>(this);
CUDA_KERNEL_CALL((generate_hashmap_duplicates<IdType, BLOCK_SIZE, TILE_SIZE>),
grid, block, 0, stream,
input,
num_input,
device_table);
CUDA_KERNEL_CALL(
(generate_hashmap_duplicates<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block,
0, stream, input, num_input, device_table);
IdType * item_prefix = static_cast<IdType*>(
device->AllocWorkspace(ctx_, sizeof(IdType)*(num_input+1)));
IdType* item_prefix = static_cast<IdType*>(
device->AllocWorkspace(ctx_, sizeof(IdType) * (num_input + 1)));
CUDA_KERNEL_CALL((count_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>),
grid, block, 0, stream,
input,
num_input,
device_table,
item_prefix);
CUDA_KERNEL_CALL(
(count_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
input, num_input, device_table, item_prefix);
size_t workspace_bytes;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(
nullptr,
workspace_bytes,
static_cast<IdType*>(nullptr),
static_cast<IdType*>(nullptr),
grid.x+1, stream));
void * workspace = device->AllocWorkspace(ctx_, workspace_bytes);
nullptr, workspace_bytes, static_cast<IdType*>(nullptr),
static_cast<IdType*>(nullptr), grid.x + 1, stream));
void* workspace = device->AllocWorkspace(ctx_, workspace_bytes);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(
workspace,
workspace_bytes,
item_prefix,
item_prefix,
grid.x+1, stream));
workspace, workspace_bytes, item_prefix, item_prefix, grid.x + 1,
stream));
device->FreeWorkspace(ctx_, workspace);
CUDA_KERNEL_CALL((compact_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>),
grid, block, 0, stream,
input,
num_input,
device_table,
item_prefix,
unique,
num_unique);
CUDA_KERNEL_CALL(
(compact_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
input, num_input, device_table, item_prefix, unique, num_unique);
device->FreeWorkspace(ctx_, item_prefix);
}
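FillWithDuplicates glues its three kernels together with cub's device-wide exclusive scan, using cub's two-phase convention: the first call with a null workspace pointer only reports the required temporary-storage size, and the second call does the work. A standalone sketch of just that convention (illustrative, not DGL code):

#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int n = 5;
  int h_in[n] = {3, 1, 4, 1, 5}, h_out[n];
  int *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(int));
  cudaMalloc(&d_out, n * sizeof(int));
  cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);

  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Pass 1: null workspace -> only sets temp_bytes.
  cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_in, d_out, n);
  cudaMalloc(&d_temp, temp_bytes);
  // Pass 2: the actual scan; h_out becomes {0, 3, 4, 8, 9}.
  cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_in, d_out, n);

  cudaMemcpy(h_out, d_out, n * sizeof(int), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) std::printf("%d ", h_out[i]);
  std::printf("\n");
  cudaFree(d_temp);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}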
template<typename IdType>
template <typename IdType>
void OrderedHashTable<IdType>::FillWithUnique(
const IdType * const input,
const size_t num_input,
cudaStream_t stream) {
const int64_t num_tiles = (num_input+TILE_SIZE-1)/TILE_SIZE;
const IdType* const input, const size_t num_input, cudaStream_t stream) {
const int64_t num_tiles = (num_input + TILE_SIZE - 1) / TILE_SIZE;
const dim3 grid(num_tiles);
const dim3 block(BLOCK_SIZE);
auto device_table = MutableDeviceOrderedHashTable<IdType>(this);
CUDA_KERNEL_CALL((generate_hashmap_unique<IdType, BLOCK_SIZE, TILE_SIZE>),
grid, block, 0, stream,
input,
num_input,
device_table);
CUDA_KERNEL_CALL(
(generate_hashmap_unique<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0,
stream, input, num_input, device_table);
}
template class OrderedHashTable<int32_t>;
......
......@@ -9,14 +9,14 @@
#include <dgl/runtime/c_runtime_api.h>
#include "cuda_runtime.h"
#include "cuda_common.h"
#include "cuda_runtime.h"
namespace dgl {
namespace runtime {
namespace cuda {
template<typename>
template <typename>
class OrderedHashTable;
/*!
......@@ -62,7 +62,7 @@ class OrderedHashTable;
*
* \tparam IdType The type of the IDs.
*/
template<typename IdType>
template <typename IdType>
class DeviceOrderedHashTable {
public:
/**
......@@ -86,10 +86,9 @@ class DeviceOrderedHashTable {
typedef const Mapping* ConstIterator;
DeviceOrderedHashTable(
const DeviceOrderedHashTable& other) = default;
DeviceOrderedHashTable& operator=(
const DeviceOrderedHashTable& other) = default;
DeviceOrderedHashTable(const DeviceOrderedHashTable& other) = default;
DeviceOrderedHashTable& operator=(const DeviceOrderedHashTable& other) =
default;
/**
* \brief Find the non-mutable mapping of a given key within the hash table.
......@@ -101,8 +100,7 @@ class DeviceOrderedHashTable {
*
* \return An iterator to the mapping.
*/
inline __device__ ConstIterator Search(
const IdType id) const {
inline __device__ ConstIterator Search(const IdType id) const {
const IdType pos = SearchForPosition(id);
return &table_[pos];
......@@ -115,8 +113,7 @@ class DeviceOrderedHashTable {
*
* \return True if the key exists in the hashtable.
*/
inline __device__ bool Contains(
const IdType id) const {
inline __device__ bool Contains(const IdType id) const {
IdType pos = Hash(id);
IdType delta = 1;
......@@ -124,8 +121,8 @@ class DeviceOrderedHashTable {
if (table_[pos].key == id) {
return true;
}
pos = Hash(pos+delta);
delta +=1;
pos = Hash(pos + delta);
delta += 1;
}
return false;
}
......@@ -134,7 +131,7 @@ class DeviceOrderedHashTable {
// Must be uniform bytes for memset to work
static constexpr IdType kEmptyKey = static_cast<IdType>(-1);
const Mapping * table_;
const Mapping* table_;
size_t size_;
/**
......@@ -143,9 +140,7 @@ class DeviceOrderedHashTable {
* \param table The table stored in GPU memory.
* \param size The size of the table.
*/
explicit DeviceOrderedHashTable(
const Mapping * table,
size_t size);
explicit DeviceOrderedHashTable(const Mapping* table, size_t size);
/**
* \brief Search for an item in the hash table which is known to exist.
......@@ -157,16 +152,15 @@ class DeviceOrderedHashTable {
*
 * \return The position of the item in the hashtable.
*/
inline __device__ IdType SearchForPosition(
const IdType id) const {
inline __device__ IdType SearchForPosition(const IdType id) const {
IdType pos = Hash(id);
// linearly scan for matching entry
IdType delta = 1;
while (table_[pos].key != id) {
assert(table_[pos].key != kEmptyKey);
pos = Hash(pos+delta);
delta +=1;
pos = Hash(pos + delta);
delta += 1;
}
assert(pos < size_);
......@@ -180,10 +174,7 @@ class DeviceOrderedHashTable {
*
* \return The hash.
*/
inline __device__ size_t Hash(
const IdType id) const {
return id % size_;
}
inline __device__ size_t Hash(const IdType id) const { return id % size_; }
friend class OrderedHashTable<IdType>;
};
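Search, Contains, and SearchForPosition all walk the same probe sequence: start at id % size and advance by an increment that grows by one each step, which visits every bucket when the table size is a power of two (hence the TableSize comment about quadratic probing). A small host-only model of that sequence (illustrative, not DGL code):

#include <cstddef>
#include <cstdio>

// Probe positions visited for `id` in a table of `size` buckets, mirroring
// pos = Hash(pos + delta); delta += 1; from the methods above.
void print_probe_sequence(size_t id, size_t size, int steps) {
  size_t pos = id % size;
  size_t delta = 1;
  for (int i = 0; i < steps; ++i) {
    std::printf("%zu ", pos);
    pos = (pos + delta) % size;
    delta += 1;
  }
  std::printf("\n");
}

int main() {
  // id = 10, 8 buckets: prints "2 3 5 0 4 1 7 6" -- all 8 buckets are reached.
  print_probe_sequence(10, 8, 8);
  return 0;
}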
......@@ -219,7 +210,7 @@ class DeviceOrderedHashTable {
*
* \tparam IdType The type of the IDs.
*/
template<typename IdType>
template <typename IdType>
class OrderedHashTable {
public:
static constexpr int kDefaultScale = 3;
......@@ -237,9 +228,7 @@ class OrderedHashTable {
* \param stream The stream to use for initializing the hashtable.
*/
OrderedHashTable(
const size_t size,
DGLContext ctx,
cudaStream_t stream,
const size_t size, DGLContext ctx, cudaStream_t stream,
const int scale = kDefaultScale);
/**
......@@ -248,10 +237,8 @@ class OrderedHashTable {
~OrderedHashTable();
// Disable copying
OrderedHashTable(
const OrderedHashTable& other) = delete;
OrderedHashTable& operator=(
const OrderedHashTable& other) = delete;
OrderedHashTable(const OrderedHashTable& other) = delete;
OrderedHashTable& operator=(const OrderedHashTable& other) = delete;
/**
* \brief Fill the hashtable with the array containing possibly duplicate
......@@ -264,11 +251,8 @@ class OrderedHashTable {
* \param stream The stream to perform operations on.
*/
void FillWithDuplicates(
const IdType * const input,
const size_t num_input,
IdType * const unique,
int64_t * const num_unique,
cudaStream_t stream);
const IdType* const input, const size_t num_input, IdType* const unique,
int64_t* const num_unique, cudaStream_t stream);
/**
* \brief Fill the hashtable with an array of unique keys.
......@@ -278,9 +262,7 @@ class OrderedHashTable {
* \param stream The stream to perform operations on.
*/
void FillWithUnique(
const IdType * const input,
const size_t num_input,
cudaStream_t stream);
const IdType* const input, const size_t num_input, cudaStream_t stream);
/**
 * \brief Get a version of the hashtable usable from device functions.
......@@ -290,12 +272,11 @@ class OrderedHashTable {
DeviceOrderedHashTable<IdType> DeviceHandle() const;
private:
Mapping * table_;
Mapping* table_;
size_t size_;
DGLContext ctx_;
};
} // namespace cuda
} // namespace runtime
} // namespace dgl
......
This diff is collapsed.
......@@ -17,7 +17,6 @@
* \brief Wrapper around NCCL routines.
*/
#ifndef DGL_RUNTIME_CUDA_NCCL_API_H_
#define DGL_RUNTIME_CUDA_NCCL_API_H_
......@@ -27,11 +26,14 @@
// if not compiling with NCCL, this class will only support communicators of
// size 1.
#define NCCL_UNIQUE_ID_BYTES 128
typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId;
typedef struct {
char internal[NCCL_UNIQUE_ID_BYTES];
} ncclUniqueId;
typedef int ncclComm_t;
#endif
#include <dgl/runtime/object.h>
#include <string>
namespace dgl {
......@@ -59,17 +61,13 @@ DGL_DEFINE_OBJECT_REF(NCCLUniqueIdRef, NCCLUniqueId);
class NCCLCommunicator : public runtime::Object {
public:
NCCLCommunicator(
int size,
int rank,
ncclUniqueId id);
NCCLCommunicator(int size, int rank, ncclUniqueId id);
~NCCLCommunicator();
// disable copying
NCCLCommunicator(const NCCLCommunicator& other) = delete;
NCCLCommunicator& operator=(
const NCCLCommunicator& other);
NCCLCommunicator& operator=(const NCCLCommunicator& other);
ncclComm_t Get();
......@@ -81,12 +79,9 @@ class NCCLCommunicator : public runtime::Object {
* @param count The size of data to send to each rank.
* @param stream The stream to operate on.
*/
template<typename IdType>
template <typename IdType>
void AllToAll(
const IdType * send,
IdType * recv,
int64_t count,
cudaStream_t stream);
const IdType* send, IdType* recv, int64_t count, cudaStream_t stream);
/**
* @brief Perform an all-to-all variable sized communication.
......@@ -99,13 +94,10 @@ class NCCLCommunicator : public runtime::Object {
* @param type The type of data to send.
* @param stream The stream to operate on.
*/
template<typename DType>
template <typename DType>
void AllToAllV(
const DType * const send,
const int64_t * send_prefix,
DType * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
const DType* const send, const int64_t* send_prefix, DType* const recv,
const int64_t* recv_prefix, cudaStream_t stream);
/**
* @brief Perform an all-to-all with sparse data (idx and value pairs). By
......@@ -124,16 +116,11 @@ class NCCLCommunicator : public runtime::Object {
 * receive on the host.
* @param stream The stream to communicate on.
*/
template<typename IdType, typename DType>
template <typename IdType, typename DType>
void SparseAllToAll(
const IdType * send_idx,
const DType * send_value,
const int64_t num_feat,
const int64_t * send_prefix,
IdType * recv_idx,
DType * recv_value,
const int64_t * recv_prefix,
cudaStream_t stream);
const IdType* send_idx, const DType* send_value, const int64_t num_feat,
const int64_t* send_prefix, IdType* recv_idx, DType* recv_value,
const int64_t* recv_prefix, cudaStream_t stream);
int size() const;
......
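NCCL itself has no single all-to-all collective, so an AllToAll like the one declared above is usually built from grouped point-to-point transfers. A hedged sketch of that pattern (requires NCCL >= 2.7 for ncclSend/ncclRecv; this is one common implementation, not necessarily what nccl_api.cc does, and error handling is omitted):

#include <cuda_runtime.h>
#include <nccl.h>
#include <cstdint>

// Every rank sends `count` elements to each peer and receives `count`
// elements from each peer; buffers are laid out contiguously per peer.
ncclResult_t AllToAllInt64(
    const int64_t* send, int64_t* recv, int64_t count, ncclComm_t comm,
    cudaStream_t stream) {
  int world_size = 0;
  ncclCommCount(comm, &world_size);
  ncclGroupStart();
  for (int peer = 0; peer < world_size; ++peer) {
    ncclSend(send + peer * count, count, ncclInt64, peer, comm, stream);
    ncclRecv(recv + peer * count, count, ncclInt64, peer, comm, stream);
  }
  return ncclGroupEnd();
}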
......@@ -3,12 +3,12 @@
* \file src/runtime/dlpack_convert.cc
* \brief Conversion between NDArray and DLPack.
*/
#include <dgl/runtime/dlpack_convert.h>
#include <dlpack/dlpack.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/dlpack_convert.h>
#include <dgl/runtime/ndarray.h>
#include <dlpack/dlpack.h>
#include "runtime_base.h"
// deleter for arrays used by DLPack exporter
......@@ -69,8 +69,7 @@ NDArray DLPackConvert::FromDLPack(DLManagedTensor* tensor) {
void DLPackConvert::DLPackDeleter(NDArray::Container* ptr) {
// if the array is pinned by dgl, unpin it before freeing
if (ptr->pinned_by_dgl_)
NDArray::UnpinContainer(ptr);
if (ptr->pinned_by_dgl_) NDArray::UnpinContainer(ptr);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);
if (tensor->deleter != nullptr) {
(*tensor->deleter)(tensor);
......@@ -95,7 +94,7 @@ DLManagedTensor* ContainerToDLPack(NDArray::Container* from) {
return ret;
}
DLManagedTensor* DLPackConvert::ToDLPack(const NDArray &from) {
DLManagedTensor* DLPackConvert::ToDLPack(const NDArray& from) {
return ContainerToDLPack(from.data_);
}
......@@ -113,15 +112,14 @@ inline bool IsAligned(const void* ptr, std::uintptr_t alignment) noexcept {
return !(iptr % alignment);
}
int DGLArrayFromDLPack(DLManagedTensor* from,
DGLArrayHandle* out) {
int DGLArrayFromDLPack(DLManagedTensor* from, DGLArrayHandle* out) {
API_BEGIN();
*out = NDArray::Internal::MoveAsDGLArray(DLPackConvert::FromDLPack(from));
API_END();
}
int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment) {
int DGLArrayToDLPack(
DGLArrayHandle from, DLManagedTensor** out, int alignment) {
API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(from);
DGLArray* nd = &(nd_container->dl_tensor);
......
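The manager_ctx / deleter handshake that FromDLPack and DLPackDeleter rely on is part of the DLPack protocol itself: the producer of a DLManagedTensor attaches a deleter, and whoever ends up owning the tensor calls it exactly once. A minimal producer-side sketch against dlpack.h (illustrative only; recent dlpack headers name the device field `device`, older ones call it `ctx`):

#include <dlpack/dlpack.h>
#include <cstdint>
#include <cstdlib>

// Wrap a freshly malloc'd 1-D float buffer in a DLManagedTensor; the consumer
// is expected to call tensor->deleter(tensor) exactly once when done with it.
DLManagedTensor* MakeManagedTensor(int64_t n) {
  float* data = static_cast<float*>(std::malloc(n * sizeof(float)));
  int64_t* shape = static_cast<int64_t*>(std::malloc(sizeof(int64_t)));
  shape[0] = n;
  DLManagedTensor* tensor = new DLManagedTensor();
  tensor->dl_tensor.data = data;
  tensor->dl_tensor.device = {kDLCPU, 0};  // `ctx` in pre-0.6 dlpack headers
  tensor->dl_tensor.ndim = 1;
  tensor->dl_tensor.dtype = {kDLFloat, 32, 1};
  tensor->dl_tensor.shape = shape;
  tensor->dl_tensor.strides = nullptr;  // compact row-major
  tensor->dl_tensor.byte_offset = 0;
  tensor->manager_ctx = nullptr;
  tensor->deleter = [](DLManagedTensor* self) {
    std::free(self->dl_tensor.data);
    std::free(self->dl_tensor.shape);
    delete self;
  };
  return tensor;
}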
......@@ -4,8 +4,9 @@
* \brief Module to load from dynamic shared library.
*/
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include "module_util.h"
#if defined(_WIN32)
......@@ -25,9 +26,7 @@ class DSOModuleNode final : public ModuleNode {
if (lib_handle_) Unload();
}
const char* type_key() const final {
return "dso";
}
const char* type_key() const final { return "dso"; }
PackedFunc GetFunction(
const std::string& name,
......@@ -36,8 +35,9 @@ class DSOModuleNode final : public ModuleNode {
if (name == runtime::symbol::dgl_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::dgl_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::dgl_module_main << " is not presented";
CHECK(entry_name != nullptr)
<< "Symbol " << runtime::symbol::dgl_module_main
<< " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
......@@ -48,16 +48,14 @@ class DSOModuleNode final : public ModuleNode {
void Init(const std::string& name) {
Load(name);
if (auto *ctx_addr =
reinterpret_cast<void**>(GetSymbol(runtime::symbol::dgl_module_ctx))) {
if (auto* ctx_addr = reinterpret_cast<void**>(
GetSymbol(runtime::symbol::dgl_module_ctx))) {
*ctx_addr = this;
}
InitContextFunctions([this](const char* fname) {
return GetSymbol(fname);
});
InitContextFunctions(
[this](const char* fname) { return GetSymbol(fname); });
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
const char* dev_mblob = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::dgl_dev_mblob));
if (dev_mblob != nullptr) {
ImportModuleBlob(dev_mblob, &imports_);
......@@ -81,9 +79,7 @@ class DSOModuleNode final : public ModuleNode {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
}
void Unload() { FreeLibrary(lib_handle_); }
#else
// Library handle
void* lib_handle_{nullptr};
......@@ -91,20 +87,15 @@ class DSOModuleNode final : public ModuleNode {
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name
<< " " << dlerror();
}
void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name);
}
void Unload() {
dlclose(lib_handle_);
<< "Failed to load dynamic shared library " << name << " " << dlerror();
}
void* GetSymbol(const char* name) { return dlsym(lib_handle_, name); }
void Unload() { dlclose(lib_handle_); }
#endif
};
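The POSIX branch above is the classic dlopen / dlsym / dlclose sequence with the same RTLD_LAZY | RTLD_LOCAL flags. A minimal standalone illustration (not DGL code; the library name is platform-specific, and libm.so.6 here assumes glibc Linux):

#include <dlfcn.h>
#include <cstdio>

int main() {
  void* handle = dlopen("libm.so.6", RTLD_LAZY | RTLD_LOCAL);
  if (!handle) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  using CosFn = double (*)(double);
  // Look up a symbol by name, exactly as GetSymbol does above.
  CosFn cos_fn = reinterpret_cast<CosFn>(dlsym(handle, "cos"));
  if (cos_fn) std::printf("cos(0) = %f\n", cos_fn(0.0));
  dlclose(handle);
  return 0;
}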
DGL_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]);
*rv = runtime::Module(n);
......