Unverified Commit 401e1278 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4811)



* [Misc] clang-format auto fix.

* fix

* manual
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 6c53f351
@@ -11,7 +11,7 @@
#include <ws2tcpip.h>
#pragma comment(lib, "Ws2_32.lib")
#else  // !_WIN32
#include <sys/socket.h>
#endif  // _WIN32
#include <string>
@@ -20,7 +20,7 @@ namespace dgl {
namespace network {
/*!
 * \brief TCPSocket is a simple wrapper around a socket.
 * It supports only TCP connections.
 */
class TCPSocket {
@@ -32,7 +32,7 @@ class TCPSocket {
  /*!
   * \brief TCPSocket deconstructor
   */
  ~TCPSocket();
  /*!
@@ -41,7 +41,7 @@ class TCPSocket {
   * \param port end port
   * \return true for success and false for failure
   */
-  bool Connect(const char * ip, int port);
+  bool Connect(const char* ip, int port);
  /*!
   * \brief Bind on the given IP and PORT
@@ -49,7 +49,7 @@ class TCPSocket {
   * \param port end port
   * \return true for success and false for failure
   */
-  bool Bind(const char * ip, int port);
+  bool Bind(const char* ip, int port);
  /*!
   * \brief listen for remote connection
@@ -65,9 +65,7 @@ class TCPSocket {
   * \param port_client new PORT will be stored to port_client
   * \return true for success and false for failure
   */
-  bool Accept(TCPSocket * socket,
-              std::string * ip_client,
-              int * port_client);
+  bool Accept(TCPSocket* socket, std::string* ip_client, int* port_client);
  /*!
   * \brief SetNonBlocking() is needed refering to this example of epoll:
@@ -103,27 +101,27 @@ class TCPSocket {
   * \param data data for sending
   * \param len_data length of data
   * \return return number of bytes sent if OK, -1 on error
   */
-  int64_t Send(const char * data, int64_t len_data);
+  int64_t Send(const char* data, int64_t len_data);
  /*!
   * \brief Receive data.
   * \param buffer buffer for receving
   * \param size_buffer size of buffer
   * \return return number of bytes received if OK, -1 on error
   */
-  int64_t Receive(char * buffer, int64_t size_buffer);
+  int64_t Receive(char* buffer, int64_t size_buffer);
  /*!
   * \brief Get socket's file descriptor
   * \return socket's file descriptor
   */
  int Socket() const;
 private:
  /*!
   * \brief socket's file descriptor
   */
  int socket_;
};
...
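For orientation, the TCPSocket members touched above are thin wrappers over the usual BSD socket calls. Below is a minimal standalone POSIX sketch of the same connect/send/receive flow, not DGL code: the address 127.0.0.1:50051 is an arbitrary placeholder and error handling is trimmed.

// Sketch only: the raw POSIX calls that a wrapper like TCPSocket encapsulates.
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cstring>
#include <iostream>
#include <string>

int main() {
  int fd = socket(AF_INET, SOCK_STREAM, 0);  // create a TCP socket
  if (fd < 0) return 1;
  sockaddr_in addr;
  std::memset(&addr, 0, sizeof(addr));
  addr.sin_family = AF_INET;
  addr.sin_port = htons(50051);              // placeholder port
  inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);
  if (connect(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0) {
    close(fd);
    return 1;                                // nobody listening on that port
  }
  std::string msg = "hello";
  send(fd, msg.data(), msg.size(), 0);       // cf. TCPSocket::Send
  char buf[64];
  ssize_t n = recv(fd, buf, sizeof(buf), 0);  // cf. TCPSocket::Receive
  if (n > 0) std::cout << std::string(buf, n) << "\n";
  close(fd);
  return 0;
}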
@@ -101,7 +101,8 @@ void InitGlobalTpContext() {
  char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS");
  if (numUvThreads_str) {
    int numUvThreads = std::atoi(numUvThreads_str);
-    CHECK(numUvThreads > 0) << "DGL_SOCKET_NTHREADS should be positive integer if set";
+    CHECK(numUvThreads > 0)
+        << "DGL_SOCKET_NTHREADS should be positive integer if set";
    // Register multiplex uv channel
    std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
    std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
@@ -111,8 +112,8 @@ void InitGlobalTpContext() {
      contexts.push_back(std::move(context));
      listeners.push_back(contexts.back()->listen(address));
    }
-    auto mptChannel = tensorpipe::channel::mpt::create(std::move(contexts),
-                                                       std::move(listeners));
+    auto mptChannel = tensorpipe::channel::mpt::create(
+        std::move(contexts), std::move(listeners));
    context->registerChannel(20 /* high priority */, "mpt", mptChannel);
  }
}
@@ -120,261 +121,265 @@ void InitGlobalTpContext() {
//////////////////////////// C APIs ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
    .set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::Reset(); });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      int64_t msg_queue_size = args[0];
      std::string type = args[1];
      int max_thread_count = args[2];
      if (type == "tensorpipe") {
        InitGlobalTpContext();
        RPCContext::getInstance()->sender.reset(
            new TPSender(RPCContext::getInstance()->ctx));
      } else if (type == "socket") {
        RPCContext::getInstance()->sender.reset(
            new network::SocketSender(msg_queue_size, max_thread_count));
      } else {
        LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
      }
      LOG(INFO) << "Sender with NetType~"
-                << RPCContext::getInstance()->sender->NetType() << " is created.";
+                << RPCContext::getInstance()->sender->NetType()
+                << " is created.";
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      int64_t msg_queue_size = args[0];
      std::string type = args[1];
      int max_thread_count = args[2];
      if (type == "tensorpipe") {
        InitGlobalTpContext();
        RPCContext::getInstance()->receiver.reset(
            new TPReceiver(RPCContext::getInstance()->ctx));
      } else if (type == "socket") {
        RPCContext::getInstance()->receiver.reset(
            new network::SocketReceiver(msg_queue_size, max_thread_count));
      } else {
        LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
      }
      LOG(INFO) << "Receiver with NetType~"
-                << RPCContext::getInstance()->receiver->NetType() << " is created.";
+                << RPCContext::getInstance()->receiver->NetType()
+                << " is created.";
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      RPCContext::getInstance()->sender->Finalize();
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      RPCContext::getInstance()->receiver->Finalize();
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      std::string ip = args[0];
      int port = args[1];
      int num_sender = args[2];
      bool blocking = args[3];
      std::string addr;
      addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
-      if (RPCContext::getInstance()->receiver->Wait(addr, num_sender, blocking) == false) {
+      if (RPCContext::getInstance()->receiver->Wait(
+              addr, num_sender, blocking) == false) {
        LOG(FATAL) << "Wait sender socket failed.";
      }
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      std::string ip = args[0];
      int port = args[1];
      int recv_id = args[2];
      std::string addr;
      addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
      *rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id);
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int max_try_times = args[0];
-      *rv = RPCContext::getInstance()->sender->ConnectReceiverFinalize(max_try_times);
+      *rv = RPCContext::getInstance()->sender->ConnectReceiverFinalize(
+          max_try_times);
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t rank = args[0];
      RPCContext::getInstance()->rank = rank;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = RPCContext::getInstance()->rank;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t num_servers = args[0];
      *rv = RPCContext::getInstance()->num_servers = num_servers;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = RPCContext::getInstance()->num_servers;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t num_clients = args[0];
      *rv = RPCContext::getInstance()->num_clients = num_clients;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = RPCContext::getInstance()->num_clients;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t num_servers = args[0];
      *rv = RPCContext::getInstance()->num_servers_per_machine = num_servers;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = RPCContext::getInstance()->num_servers_per_machine;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = (RPCContext::getInstance()->msg_seq)++;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = RPCContext::getInstance()->msg_seq;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t msg_seq = args[0];
      RPCContext::getInstance()->msg_seq = msg_seq;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t group_id = args[0];
      auto&& cnt = RPCContext::getInstance()->barrier_count;
      if (cnt.find(group_id) == cnt.end()) {
        cnt.emplace(group_id, 0x0);
      }
      *rv = cnt[group_id];
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t count = args[0];
      const int32_t group_id = args[1];
      RPCContext::getInstance()->barrier_count[group_id] = count;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = RPCContext::getInstance()->machine_id;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t machine_id = args[0];
      RPCContext::getInstance()->machine_id = machine_id;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = RPCContext::getInstance()->num_machines;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t num_machines = args[0];
      RPCContext::getInstance()->num_machines = num_machines;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      RPCMessageRef msg = args[0];
      const int32_t target_id = args[1];
      *rv = SendRPCMessage(*(msg.sptr()), target_id);
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      int32_t timeout = args[0];
      RPCMessageRef msg = args[1];
      *rv = RecvRPCMessage(msg.sptr().get(), timeout);
    });
//////////////////////////// RPCMessage ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      std::shared_ptr<RPCMessage> rst(new RPCMessage);
      *rv = rst;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      std::shared_ptr<RPCMessage> rst(new RPCMessage);
      rst->service_id = args[0];
      rst->msg_seq = args[1];
      rst->client_id = args[2];
      rst->server_id = args[3];
      const std::string data =
          args[4];  // directly assigning string value raises errors :(
      rst->data = data;
      rst->tensors = ListValueToVector<NDArray>(args[5]);
      rst->group_id = args[6];
      *rv = rst;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef msg = args[0];
      *rv = msg->service_id;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef msg = args[0];
      *rv = msg->msg_seq;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef msg = args[0];
      *rv = msg->client_id;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef msg = args[0];
      *rv = msg->server_id;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef msg = args[0];
      DGLByteArray barr{msg->data.c_str(), msg->data.size()};
      *rv = barr;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef msg = args[0];
      List<Value> ret;
      for (size_t i = 0; i < msg->tensors.size(); ++i) {
        ret.push_back(Value(MakeValue(msg->tensors[i])));
      }
      *rv = ret;
    });
#if defined(__linux__)
/*!
@@ -388,187 +393,190 @@ void SigHandler(int s) {
}
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Ctrl+C handler
      struct sigaction sigHandler;
      sigHandler.sa_handler = SigHandler;
      sigemptyset(&sigHandler.sa_mask);
      sigHandler.sa_flags = 0;
      sigaction(SIGINT, &sigHandler, nullptr);
      sigaction(SIGTERM, &sigHandler, nullptr);
    });
#endif
//////////////////////////// ServerState ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto st = RPCContext::getInstance()->server_state;
      if (st.get() == nullptr) {
-        RPCContext::getInstance()->server_state = std::make_shared<ServerState>();
+        RPCContext::getInstance()->server_state =
+            std::make_shared<ServerState>();
      }
      *rv = st;
    });
//////////////////////////// KVStore ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray ID = args[0];
      NDArray part_id = args[1];
      int local_machine_id = args[2];
      int64_t* ID_data = static_cast<int64_t*>(ID->data);
      int64_t* part_id_data = static_cast<int64_t*>(part_id->data);
      int64_t ID_size = ID.GetSize() / sizeof(int64_t);
      std::vector<int64_t> global_id;
      for (int64_t i = 0; i < ID_size; ++i) {
        if (part_id_data[i] == local_machine_id) {
          global_id.push_back(ID_data[i]);
        }
      }
      NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);
      *rv = res_tensor;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Input
      std::string name = args[0];
      int local_machine_id = args[1];
      int machine_count = args[2];
      int group_count = args[3];
      int client_id = args[4];
      int service_id = args[5];
      int64_t msg_seq = args[6];
      std::string pickle_data = args[7];
      NDArray ID = args[8];
      NDArray part_id = args[9];
      NDArray local_id = args[10];
      NDArray local_data = args[11];
      // Data
      dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
      dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
      dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
      dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
      char* local_data_char = static_cast<char*>(local_data->data);
      std::vector<dgl_id_t> local_ids;
      std::vector<dgl_id_t> local_ids_orginal;
      std::vector<int64_t> local_data_shape;
      std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
      std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);
      // Get row size (in bytes)
      int row_size = 1;
      for (int i = 0; i < local_data->ndim; ++i) {
        local_data_shape.push_back(local_data->shape[i]);
        if (i != 0) {
          row_size *= local_data->shape[i];
        }
      }
      row_size *= (local_data->dtype.bits / 8);
      size_t data_size = local_data.GetSize();
      CHECK_GT(local_data_shape.size(), 0);
      CHECK_EQ(row_size * local_data_shape[0], data_size);
      // Get local id (used in local machine) and
      // remote id (send to remote machine)
      dgl_id_t idx = 0;
      for (dgl_id_t i = 0; i < ID_size; ++i) {
        dgl_id_t p_id = part_id_data[i];
        if (static_cast<int>(p_id) == local_machine_id) {
          dgl_id_t l_id = local_id_data[idx++];
          CHECK_LT(l_id, local_data_shape[0]);
          CHECK_GE(l_id, 0);
          local_ids.push_back(l_id);
          local_ids_orginal.push_back(i);
        } else {
          CHECK_LT(p_id, machine_count) << "Invalid partition ID.";
          dgl_id_t id = ID_data[i];
          remote_ids[p_id].push_back(id);
          remote_ids_original[p_id].push_back(i);
        }
      }
      // Send remote id
      int msg_count = 0;
      for (size_t i = 0; i < remote_ids.size(); ++i) {
        if (remote_ids[i].size() != 0) {
          RPCMessage msg;
          msg.service_id = service_id;
          msg.msg_seq = msg_seq;
          msg.client_id = client_id;
          int lower = i * group_count;
          int upper = (i + 1) * group_count;
-          msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
+          msg.server_id =
+              dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
          msg.data = pickle_data;
          NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
          msg.tensors.push_back(tensor);
          msg.group_id = RPCContext::getInstance()->group_id;
          SendRPCMessage(msg, msg.server_id);
          msg_count++;
        }
      }
      local_data_shape[0] = ID_size;
-      NDArray res_tensor = NDArray::Empty(local_data_shape,
-                                          local_data->dtype,
-                                          DGLContext{kDGLCPU, 0});
+      NDArray res_tensor = NDArray::Empty(
+          local_data_shape, local_data->dtype, DGLContext{kDGLCPU, 0});
      char* return_data = static_cast<char*>(res_tensor->data);
      // Copy local data
      parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
        for (auto i = b; i < e; ++i) {
-          CHECK_GE(ID_size * row_size,
-                   local_ids_orginal[i] * row_size + row_size);
+          CHECK_GE(
+              ID_size * row_size, local_ids_orginal[i] * row_size + row_size);
          CHECK_GE(data_size, local_ids[i] * row_size + row_size);
          CHECK_GE(local_ids[i], 0);
-          memcpy(return_data + local_ids_orginal[i] * row_size,
-                 local_data_char + local_ids[i] * row_size, row_size);
+          memcpy(
+              return_data + local_ids_orginal[i] * row_size,
+              local_data_char + local_ids[i] * row_size, row_size);
        }
      });
      // Recv remote message
      int recv_cnt = 0;
      while (recv_cnt < msg_count) {
        RPCMessage msg;
        auto status = RecvRPCMessage(&msg, 0);
        CHECK_EQ(status, kRPCSuccess);
        ++recv_cnt;
        int part_id = msg.server_id / group_count;
        char* data_char = static_cast<char*>(msg.tensors[0]->data);
        dgl_id_t id_size = remote_ids[part_id].size();
        for (size_t n = 0; n < id_size; ++n) {
-          memcpy(return_data + remote_ids_original[part_id][n] * row_size,
-                 data_char + n * row_size, row_size);
+          memcpy(
+              return_data + remote_ids_original[part_id][n] * row_size,
+              data_char + n * row_size, row_size);
        }
      }
      *rv = res_tensor;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = RPCContext::getInstance()->group_id;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t group_id = args[0];
      RPCContext::getInstance()->group_id = group_id;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef msg = args[0];
      *rv = msg->group_id;
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client_id = args[0];
      const int32_t group_id = args[1];
      *rv = RPCContext::getInstance()->RegisterClient(client_id, group_id);
    });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client_id = args[0];
      const int32_t group_id = args[1];
      *rv = RPCContext::getInstance()->GetClient(client_id, group_id);
    });
}  // namespace rpc
}  // namespace dgl
...
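The _CAPI_DGLRPCFastPull body above splits requested IDs into local and remote parts and assembles the result tensor row by row with memcpy. A small self-contained sketch of that row-gather pattern follows; GatherRows is a hypothetical helper used only for illustration, not a DGL API.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Copy selected rows of a flat [num_rows x row_size]-byte buffer into the
// output positions the caller asked for, mirroring local_ids /
// local_ids_orginal in the kernel above.
void GatherRows(const char* src, size_t row_size,
                const std::vector<int64_t>& src_rows,
                const std::vector<int64_t>& dst_rows, char* dst) {
  for (size_t i = 0; i < src_rows.size(); ++i) {
    std::memcpy(dst + dst_rows[i] * row_size,
                src + src_rows[i] * row_size, row_size);
  }
}

int main() {
  const size_t row_size = 2 * sizeof(float);
  std::vector<float> table = {0, 0, 1, 1, 2, 2, 3, 3};  // 4 rows of 2 floats
  std::vector<float> out(4, -1.f);                      // room for 2 rows
  GatherRows(reinterpret_cast<const char*>(table.data()), row_size,
             {3, 1}, {0, 1}, reinterpret_cast<char*>(out.data()));
  std::cout << out[0] << " " << out[2] << "\n";  // prints "3 1"
  return 0;
}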
@@ -6,24 +6,25 @@
#ifndef DGL_RPC_RPC_H_
#define DGL_RPC_RPC_H_
-#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
+#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/thread_local.h>
#include <cstdint>
-#include <memory>
#include <deque>
-#include <vector>
+#include <memory>
-#include <string>
#include <mutex>
+#include <string>
#include <unordered_map>
+#include <vector>
+#include "./network/common.h"
#include "./rpc_msg.h"
+#include "./server_state.h"
#include "net_type.h"
#include "network/socket_communicator.h"
#include "tensorpipe/tp_communicator.h"
-#include "./network/common.h"
-#include "./server_state.h"
namespace dgl {
namespace rpc {
@@ -138,7 +139,7 @@ struct RPCContext {
  }
  int32_t RegisterClient(int32_t client_id, int32_t group_id) {
-    auto &&m = clients_[group_id];
+    auto&& m = clients_[group_id];
    if (m.find(client_id) != m.end()) {
      return -1;
    }
@@ -150,7 +151,7 @@ struct RPCContext {
    if (clients_.find(group_id) == clients_.end()) {
      return -1;
    }
-    const auto &m = clients_.at(group_id);
+    const auto& m = clients_.at(group_id);
    if (m.find(client_id) == m.end()) {
      return -1;
    }
...
@@ -6,8 +6,8 @@
#ifndef DGL_RPC_RPC_MSG_H_
#define DGL_RPC_RPC_MSG_H_
-#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
+#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <string>
...
@@ -7,11 +7,12 @@
#ifndef DGL_RPC_SERVER_STATE_H_
#define DGL_RPC_SERVER_STATE_H_
-#include <dgl/runtime/object.h>
-#include <dgl/runtime/ndarray.h>
#include <dgl/base_heterograph.h>
-#include <unordered_map>
+#include <dgl/runtime/ndarray.h>
+#include <dgl/runtime/object.h>
#include <string>
+#include <unordered_map>
namespace dgl {
namespace rpc {
...
@@ -9,10 +9,11 @@
#define DGL_RPC_TENSORPIPE_QUEUE_H_
#include <dmlc/logging.h>
+#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
-#include <chrono>
#include <utility>
namespace dgl {
@@ -39,8 +40,9 @@ class Queue {
      DLOG(WARNING) << "Will wait infinitely until message is popped...";
      cv_.wait(lock, [this] { return items_.size() > 0; });
    } else {
-      if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout),
-                        [this] { return items_.size() > 0; })) {
+      if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout), [this] {
+            return items_.size() > 0;
+          })) {
        DLOG(WARNING) << "Times out for popping message after " << timeout
                      << " milliseconds.";
        return false;
...
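The Queue hunk above bounds how long a pop may block by calling condition_variable::wait_for with a predicate. A self-contained sketch of that timed-pop pattern, with a plain deque standing in for DGL's Queue:

#include <chrono>
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>

std::mutex mtx;
std::condition_variable cv;
std::deque<int> items;

bool TimedPop(int* out, int timeout_ms) {
  std::unique_lock<std::mutex> lock(mtx);
  // wait_for returns false if the predicate is still false after the timeout,
  // mirroring the "return false" branch in the hunk above.
  if (!cv.wait_for(lock, std::chrono::milliseconds(timeout_ms),
                   [] { return !items.empty(); })) {
    return false;
  }
  *out = items.front();
  items.pop_front();
  return true;
}

int main() {
  std::thread producer([] {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    { std::lock_guard<std::mutex> g(mtx); items.push_back(42); }
    cv.notify_one();
  });
  int v = 0;
  std::cout << (TimedPop(&v, 500) ? "got " + std::to_string(v)
                                  : std::string("timeout"))
            << "\n";
  producer.join();
  return 0;
}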
@@ -48,8 +48,8 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
  StreamWithBuffer zc_write_strm(zerocopy_blob_ptr, true);
  zc_write_strm.Write(msg);
  int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
-  zerocopy_blob_ptr->append(reinterpret_cast<char *>(&nonempty_ndarray_count),
-                            sizeof(int32_t));
+  zerocopy_blob_ptr->append(
+      reinterpret_cast<char *>(&nonempty_ndarray_count), sizeof(int32_t));
  tp_msg.tensors.resize(nonempty_ndarray_count);
  // Hold the NDArray that ensure it's valid until write operation completes
  auto ndarray_holder = std::make_shared<std::vector<NDArray>>();
@@ -68,14 +68,14 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
  }
  // Let's write blockingly in case of congestion in underlying transports.
  auto done = std::make_shared<std::promise<void>>();
-  pipe->write(tp_msg,
-              [ndarray_holder, recv_id, done](const tensorpipe::Error &error) {
+  pipe->write(
+      tp_msg, [ndarray_holder, recv_id, done](const tensorpipe::Error &error) {
        if (error) {
          LOG(FATAL) << "Failed to send message to " << recv_id
                     << ". Details: " << error.what();
        }
        done->set_value();
      });
  done->get_future().wait();
}
@@ -120,7 +121,8 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
  if (error.isOfType<ListenerClosedError>()) {
    // Expected.
  } else {
-    LOG(WARNING) << "Unexpected error when accepting incoming pipe: " << error.what();
+    LOG(WARNING) << "Unexpected error when accepting incoming pipe: "
+                 << error.what();
  }
  return;
}
@@ -133,7 +134,8 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
  // read the handshake message: "dglconnect"
  pipe->readDescriptor([pipe, this](const Error &error, Descriptor descriptor) {
    if (error) {
-      LOG(ERROR) << "Unexpected error when reading from accepted pipe: " << error.what();
+      LOG(ERROR) << "Unexpected error when reading from accepted pipe: "
+                 << error.what();
      return;
    }
    Allocation allocation;
@@ -145,10 +147,10 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
  });
}
-void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
-                                 std::shared_ptr<RPCMessageQueue> queue) {
+void TPReceiver::ReceiveFromPipe(
+    std::shared_ptr<Pipe> pipe, std::shared_ptr<RPCMessageQueue> queue) {
-  pipe->readDescriptor([pipe, queue = std::move(queue)](const Error &error,
-                                                        Descriptor descriptor) {
+  pipe->readDescriptor([pipe, queue = std::move(queue)](
+                           const Error &error, Descriptor descriptor) {
    if (error) {
      // Error may happen when the pipe is closed
      return;
@@ -165,31 +167,33 @@ void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
        allocation.tensors[i].buffer = cpu_buffer;
      }
    }
-    pipe->read(allocation, [allocation, descriptor = std::move(descriptor),
-                            queue = std::move(queue),
-                            pipe](const Error &error) {
+    pipe->read(
+        allocation, [allocation, descriptor = std::move(descriptor),
+                     queue = std::move(queue), pipe](const Error &error) {
          if (error) {
            // Because we always have a read event posted to the epoll,
            // Therefore when pipe is closed, error will be raised.
            // But this error is expected.
-           // Other error is not expected. But we cannot identify the error with
-           // each Other for now. Thus here we skip handling for all errors
+           // Other error is not expected. But we cannot identify the error
+           // with each Other for now. Thus here we skip handling for all
+           // errors
            return;
          }
          char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
          std::vector<void *> buffer_list(descriptor.tensors.size());
          for (size_t i = 0; i < descriptor.tensors.size(); i++) {
-           buffer_list[i] = allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
+           buffer_list[i] =
+               allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
          }
          StreamWithBuffer zc_read_strm(
              meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t),
              buffer_list);
          RPCMessage msg;
          zc_read_strm.Read(&msg);
          queue->push(msg);
          TPReceiver::ReceiveFromPipe(pipe, queue);
        });
  });
}
...
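TPSender::Send above turns tensorpipe's callback-based write into a blocking call by waiting on a std::promise that the completion callback fulfills. A minimal sketch of that pattern; RunAsync here is a made-up stand-in for the asynchronous API, not tensorpipe.

#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <thread>

// Fake async API: invokes the callback from another thread.
void RunAsync(std::function<void(bool /*error*/)> cb) {
  std::thread([cb] { cb(false); }).detach();
}

int main() {
  auto done = std::make_shared<std::promise<void>>();
  RunAsync([done](bool error) {
    if (error) std::cerr << "write failed\n";
    done->set_value();          // signal completion, like done->set_value()
  });
  done->get_future().wait();    // block until the callback fires
  std::cout << "write completed\n";
  return 0;
}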
@@ -9,15 +9,16 @@
#include <dmlc/logging.h>
#include <tensorpipe/tensorpipe.h>
+#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
-#include <atomic>
-#include "./queue.h"
#include "../net_type.h"
+#include "./queue.h"
namespace dgl {
namespace rpc {
@@ -47,11 +48,12 @@ class TPSender : public RPCSender {
  /*!
   * \brief Connect to a receiver.
   *
-   * When there are multiple receivers to be connected, application will call `ConnectReceiver`
-   * for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
-   * successfully established or some of them fail.
+   * When there are multiple receivers to be connected, application will call
+   * `ConnectReceiver` for each and then call `ConnectReceiverFinalize` to make
+   * sure that either all the connections are successfully established or some
+   * of them fail.
   *
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * \param recv_id receiver's ID
   * \return True for success and False for fail
@@ -75,7 +77,7 @@ class TPSender : public RPCSender {
  /*!
   * \brief Communicator type: 'tp'
   */
-  const std::string &NetType() const override {
+  const std::string& NetType() const override {
    static const std::string net_type = "tensorpipe";
    return net_type;
  }
@@ -90,7 +92,7 @@ class TPSender : public RPCSender {
   * \brief pipe for each connection of receiver
   */
  std::unordered_map<int /* receiver ID */, std::shared_ptr<tensorpipe::Pipe>>
      pipes_;
  /*!
   * \brief receivers' listening address
@@ -129,13 +131,14 @@ class TPReceiver : public RPCReceiver {
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
   */
-  bool Wait(const std::string &addr, int num_sender,
-            bool blocking = true) override;
+  bool Wait(
+      const std::string& addr, int num_sender, bool blocking = true) override;
  /*!
   * \brief Recv RPCMessage from Sender. Actually removing data from queue.
   * \param msg pointer of RPCmessage
-   * \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
+   * \param timeout The timeout value in milliseconds. If zero, wait
+   * indefinitely.
   * \return RPCStatus: kRPCSuccess or kRPCTimeOut.
   */
  RPCStatus Recv(RPCMessage* msg, int timeout) override;
@@ -150,7 +153,7 @@ class TPReceiver : public RPCReceiver {
  /*!
   * \brief Communicator type: 'tp' (tensorpipe)
   */
-  const std::string &NetType() const override {
+  const std::string& NetType() const override {
    static const std::string net_type = "tensorpipe";
    return net_type;
  }
@@ -158,8 +161,9 @@ class TPReceiver : public RPCReceiver {
  /*!
   * \brief Issue a receive request on pipe, and push the result into queue
   */
-  static void ReceiveFromPipe(std::shared_ptr<tensorpipe::Pipe> pipe,
-                              std::shared_ptr<RPCMessageQueue> queue);
+  static void ReceiveFromPipe(
+      std::shared_ptr<tensorpipe::Pipe> pipe,
+      std::shared_ptr<RPCMessageQueue> queue);
 private:
  /*!
@@ -186,9 +190,9 @@ class TPReceiver : public RPCReceiver {
  /*!
   * \brief pipe for each client connections
   */
-  std::unordered_map<int /* Sender (virutal) ID */,
-                     std::shared_ptr<tensorpipe::Pipe>>
+  std::unordered_map<
+      int /* Sender (virutal) ID */, std::shared_ptr<tensorpipe::Pipe>>
      pipes_;
  /*!
   * \brief RPCMessage queue
...
@@ -3,16 +3,18 @@
 * Implementation of C API (reference: tvm/src/api/c_api.cc)
 * \file c_api.cc
 */
+#include <dgl/packed_func_ext.h>
+#include <dgl/runtime/c_object_api.h>
+#include <dgl/runtime/ndarray.h>
+#include <dgl/runtime/object.h>
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
-#include <dgl/runtime/c_object_api.h>
-#include <dgl/runtime/object.h>
-#include <dgl/runtime/ndarray.h>
-#include <dgl/packed_func_ext.h>
-#include <vector>
-#include <string>
#include <exception>
+#include <string>
+#include <vector>
#include "runtime_base.h"
/*! \brief entry to to easily hold returning information */
@@ -20,7 +22,7 @@ struct DGLAPIThreadLocalEntry {
  /*! \brief result holder for returning strings */
  std::vector<std::string> ret_vec_str;
  /*! \brief result holder for returning string pointers */
-  std::vector<const char *> ret_vec_charp;
+  std::vector<const char*> ret_vec_charp;
  /*! \brief result holder for retruning string */
  std::string ret_str;
};
@@ -44,7 +46,8 @@ struct APIAttrGetter : public AttrVisitor {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, uint64_t* value) final {
-    CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
+    CHECK_LE(
+        value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
        << "cannot return too big constant";
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
@@ -71,30 +74,16 @@ struct APIAttrGetter : public AttrVisitor {
struct APIAttrDir : public AttrVisitor {
  std::vector<std::string>* names;
-  void Visit(const char* key, double* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, int64_t* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, uint64_t* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, bool* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, int* value) final {
-    names->push_back(key);
-  }
+  void Visit(const char* key, double* value) final { names->push_back(key); }
+  void Visit(const char* key, int64_t* value) final { names->push_back(key); }
+  void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
+  void Visit(const char* key, bool* value) final { names->push_back(key); }
+  void Visit(const char* key, int* value) final { names->push_back(key); }
  void Visit(const char* key, std::string* value) final {
    names->push_back(key);
  }
-  void Visit(const char* key, ObjectRef* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, NDArray* value) final {
-    names->push_back(key);
-  }
+  void Visit(const char* key, ObjectRef* value) final { names->push_back(key); }
+  void Visit(const char* key, NDArray* value) final { names->push_back(key); }
};
int DGLObjectFree(ObjectHandle handle) {
@@ -103,26 +92,22 @@ int DGLObjectFree(ObjectHandle handle) {
  API_END();
}
-int DGLObjectTypeKey2Index(const char* type_key,
-                           int* out_index) {
+int DGLObjectTypeKey2Index(const char* type_key, int* out_index) {
  API_BEGIN();
  *out_index = static_cast<int>(Object::TypeKey2Index(type_key));
  API_END();
}
-int DGLObjectGetTypeIndex(ObjectHandle handle,
-                          int* out_index) {
+int DGLObjectGetTypeIndex(ObjectHandle handle, int* out_index) {
  API_BEGIN();
-  *out_index = static_cast<int>(
-      (*static_cast<DGLAPIObject*>(handle))->type_index());
+  *out_index =
+      static_cast<int>((*static_cast<DGLAPIObject*>(handle))->type_index());
  API_END();
}
-int DGLObjectGetAttr(ObjectHandle handle,
-                     const char* key,
-                     DGLValue* ret_val,
-                     int* ret_type_code,
-                     int* ret_success) {
+int DGLObjectGetAttr(
+    ObjectHandle handle, const char* key, DGLValue* ret_val, int* ret_type_code,
    int* ret_success) {
  API_BEGIN();
  DGLRetValue rv;
  APIAttrGetter getter;
@@ -136,9 +121,8 @@ int DGLObjectGetAttr(ObjectHandle handle,
  } else {
    (*tobject)->VisitAttrs(&getter);
    *ret_success = getter.found_object_ref || rv.type_code() != kNull;
-    if (rv.type_code() == kStr ||
-        rv.type_code() == kDGLDataType) {
-      DGLAPIThreadLocalEntry *e = DGLAPIThreadLocalStore::Get();
+    if (rv.type_code() == kStr || rv.type_code() == kDGLDataType) {
+      DGLAPIThreadLocalEntry* e = DGLAPIThreadLocalStore::Get();
      e->ret_str = rv.operator std::string();
      *ret_type_code = kStr;
      ret_val->v_str = e->ret_str.c_str();
@@ -149,10 +133,9 @@ int DGLObjectGetAttr(ObjectHandle handle,
  API_END();
}
-int DGLObjectListAttrNames(ObjectHandle handle,
-                           int *out_size,
-                           const char*** out_array) {
-  DGLAPIThreadLocalEntry *ret = DGLAPIThreadLocalStore::Get();
+int DGLObjectListAttrNames(
+    ObjectHandle handle, int* out_size, const char*** out_array) {
+  DGLAPIThreadLocalEntry* ret = DGLAPIThreadLocalStore::Get();
  API_BEGIN();
  ret->ret_vec_str.clear();
  DGLAPIObject* tobject = static_cast<DGLAPIObject*>(handle);
...
@@ -3,18 +3,20 @@
 * \file c_runtime_api.cc
 * \brief Runtime API implementation
 */
-#include <dmlc/thread_local.h>
-#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
-#include <dgl/runtime/packed_func.h>
+#include <dgl/runtime/c_runtime_api.h>
+#include <dgl/runtime/device_api.h>
#include <dgl/runtime/module.h>
+#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
-#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>
-#include <array>
+#include <dmlc/thread_local.h>
#include <algorithm>
-#include <string>
+#include <array>
#include <cstdlib>
+#include <string>
#include "runtime_base.h"
namespace dgl {
@@ -26,10 +28,14 @@ namespace runtime {
 */
inline std::string DeviceName(int type) {
  switch (type) {
-    case kDGLCPU: return "cpu";
-    case kDGLCUDA: return "cuda";
+    case kDGLCPU:
+      return "cpu";
+    case kDGLCUDA:
+      return "cuda";
    // add more device here once supported
-    default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
+    default:
+      LOG(FATAL) << "unknown type =" << type;
+      return "Unknown";
  }
}
@@ -37,9 +43,7 @@ class DeviceAPIManager {
 public:
  static const int kMaxDeviceAPI = 32;
  // Get API
-  static DeviceAPI* Get(const DGLContext& ctx) {
-    return Get(ctx.device_type);
-  }
+  static DeviceAPI* Get(const DGLContext& ctx) { return Get(ctx.device_type); }
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }
@@ -49,9 +53,7 @@ class DeviceAPIManager {
  DeviceAPI* rpc_api_{nullptr};
  std::mutex mutex_;
  // constructor
-  DeviceAPIManager() {
-    std::fill(api_.begin(), api_.end(), nullptr);
-  }
+  DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
@@ -78,7 +80,8 @@ class DeviceAPIManager {
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
-          << "Device API " << name << " is not enabled. Please install the cuda version of dgl.";
+          << "Device API " << name
+          << " is not enabled. Please install the cuda version of dgl.";
      return nullptr;
    }
    void* ptr = (*f)();
@@ -95,9 +98,8 @@ DeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {
  return DeviceAPIManager::Get(static_cast<int>(dev_type), allow_missing);
}
-void* DeviceAPI::AllocWorkspace(DGLContext ctx,
-                                size_t size,
-                                DGLDataType type_hint) {
+void* DeviceAPI::AllocWorkspace(
+    DGLContext ctx, size_t size, DGLDataType type_hint) {
  return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}
@@ -114,9 +116,8 @@ void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
  LOG(FATAL) << "Device does not support stream api.";
}
-void DeviceAPI::SyncStreamFromTo(DGLContext ctx,
-                                 DGLStreamHandle event_src,
-                                 DGLStreamHandle event_dst) {
+void DeviceAPI::SyncStreamFromTo(
+    DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
  LOG(FATAL) << "Device does not support stream api.";
}
@@ -140,7 +141,7 @@ struct DGLRuntimeEntry {
typedef dmlc::ThreadLocalStore<DGLRuntimeEntry> DGLAPIRuntimeStore;
const char *DGLGetLastError() { const char* DGLGetLastError() {
return DGLAPIRuntimeStore::Get()->last_error.c_str(); return DGLAPIRuntimeStore::Get()->last_error.c_str();
} }
...@@ -152,30 +153,26 @@ void DGLAPISetLastError(const char* msg) { ...@@ -152,30 +153,26 @@ void DGLAPISetLastError(const char* msg) {
#endif #endif
} }
int DGLModLoadFromFile(const char* file_name, int DGLModLoadFromFile(
const char* format, const char* file_name, const char* format, DGLModuleHandle* out) {
DGLModuleHandle* out) {
API_BEGIN(); API_BEGIN();
Module m = Module::LoadFromFile(file_name, format); Module m = Module::LoadFromFile(file_name, format);
*out = new Module(m); *out = new Module(m);
API_END(); API_END();
} }
int DGLModImport(DGLModuleHandle mod, int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep) {
DGLModuleHandle dep) {
API_BEGIN(); API_BEGIN();
static_cast<Module*>(mod)->Import( static_cast<Module*>(mod)->Import(*static_cast<Module*>(dep));
*static_cast<Module*>(dep));
API_END(); API_END();
} }
int DGLModGetFunction(DGLModuleHandle mod, int DGLModGetFunction(
const char* func_name, DGLModuleHandle mod, const char* func_name, int query_imports,
int query_imports, DGLFunctionHandle* func) {
DGLFunctionHandle *func) {
API_BEGIN(); API_BEGIN();
PackedFunc pf = static_cast<Module*>(mod)->GetFunction( PackedFunc pf =
func_name, query_imports != 0); static_cast<Module*>(mod)->GetFunction(func_name, query_imports != 0);
if (pf != nullptr) { if (pf != nullptr) {
*func = new PackedFunc(pf); *func = new PackedFunc(pf);
} else { } else {
...@@ -190,20 +187,18 @@ int DGLModFree(DGLModuleHandle mod) { ...@@ -190,20 +187,18 @@ int DGLModFree(DGLModuleHandle mod) {
API_END(); API_END();
} }
int DGLBackendGetFuncFromEnv(void* mod_node, int DGLBackendGetFuncFromEnv(
const char* func_name, void* mod_node, const char* func_name, DGLFunctionHandle* func) {
DGLFunctionHandle *func) {
API_BEGIN(); API_BEGIN();
*func = (DGLFunctionHandle)( *func =
static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name)); (DGLFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(
func_name));
API_END(); API_END();
} }
void* DGLBackendAllocWorkspace(int device_type, void* DGLBackendAllocWorkspace(
int device_id, int device_type, int device_id, uint64_t size, int dtype_code_hint,
uint64_t size, int dtype_bits_hint) {
int dtype_code_hint,
int dtype_bits_hint) {
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
...@@ -213,14 +208,11 @@ void* DGLBackendAllocWorkspace(int device_type, ...@@ -213,14 +208,11 @@ void* DGLBackendAllocWorkspace(int device_type,
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint); type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1; type_hint.lanes = 1;
return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, return DeviceAPIManager::Get(ctx)->AllocWorkspace(
static_cast<size_t>(size), ctx, static_cast<size_t>(size), type_hint);
type_hint);
} }
int DGLBackendFreeWorkspace(int device_type, int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
int device_id,
void* ptr) {
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
...@@ -228,10 +220,7 @@ int DGLBackendFreeWorkspace(int device_type, ...@@ -228,10 +220,7 @@ int DGLBackendFreeWorkspace(int device_type,
return 0; return 0;
} }
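For reference, a hedged sketch of how the two backend hooks above pair up in caller code. This is not part of the commit; the 4 KiB size and the dtype hint (code 2 = float, 32 bits, following the DLPack-style code/bits convention) are illustrative placeholders only.

#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/c_runtime_api.h>

// Allocate a temporary scratch buffer through the backend hooks shown above,
// then hand it back. Sizes and dtype hints are placeholders.
void ScratchExample() {
  void* ws = DGLBackendAllocWorkspace(
      kDGLCPU, /*device_id=*/0, /*size=*/4096,
      /*dtype_code_hint=*/2, /*dtype_bits_hint=*/32);
  if (ws != nullptr) {
    // ... use the scratch buffer ...
    DGLBackendFreeWorkspace(kDGLCPU, /*device_id=*/0, ws);
  }
}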
int DGLBackendRunOnce(void** handle, int DGLBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
int (*f)(void*),
void* cdata,
int nbytes) {
if (*handle == nullptr) { if (*handle == nullptr) {
*handle = reinterpret_cast<void*>(1); *handle = reinterpret_cast<void*>(1);
return (*f)(cdata); return (*f)(cdata);
...@@ -245,19 +234,15 @@ int DGLFuncFree(DGLFunctionHandle func) { ...@@ -245,19 +234,15 @@ int DGLFuncFree(DGLFunctionHandle func) {
API_END(); API_END();
} }
int DGLFuncCall(DGLFunctionHandle func, int DGLFuncCall(
DGLValue* args, DGLFunctionHandle func, DGLValue* args, int* arg_type_codes, int num_args,
int* arg_type_codes, DGLValue* ret_val, int* ret_type_code) {
int num_args,
DGLValue* ret_val,
int* ret_type_code) {
API_BEGIN(); API_BEGIN();
DGLRetValue rv; DGLRetValue rv;
(*static_cast<const PackedFunc*>(func)).CallPacked( (*static_cast<const PackedFunc*>(func))
DGLArgs(args, arg_type_codes, num_args), &rv); .CallPacked(DGLArgs(args, arg_type_codes, num_args), &rv);
// handle return string. // handle return string.
if (rv.type_code() == kStr || if (rv.type_code() == kStr || rv.type_code() == kDGLDataType ||
rv.type_code() == kDGLDataType ||
rv.type_code() == kBytes) { rv.type_code() == kBytes) {
DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get(); DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
if (rv.type_code() != kDGLDataType) { if (rv.type_code() != kDGLDataType) {
...@@ -280,10 +265,8 @@ int DGLFuncCall(DGLFunctionHandle func, ...@@ -280,10 +265,8 @@ int DGLFuncCall(DGLFunctionHandle func,
API_END(); API_END();
} }
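Putting the C entry points above together, the usual call sequence is load module, look up a packed function, then DGLFuncCall. The following end-to-end sketch is not from the DGL sources: the library path, function name, argument layout, and the integer type code (0, per the DLPack-style convention) are assumptions for illustration.

#include <dgl/runtime/c_runtime_api.h>

// Hedged usage sketch of the C API shown above; all names are placeholders.
static int CallPackedExample() {
  DGLModuleHandle mod;
  if (DGLModLoadFromFile("libexample.so", "so", &mod) != 0) return -1;

  DGLFunctionHandle fn;
  if (DGLModGetFunction(mod, "my_func", /*query_imports=*/1, &fn) != 0)
    return -1;

  DGLValue arg, ret;
  int arg_code = 0;  // integer argument, DLPack-style type code (assumption)
  int ret_code = 0;
  arg.v_int64 = 42;
  if (DGLFuncCall(fn, &arg, &arg_code, /*num_args=*/1, &ret, &ret_code) != 0)
    return -1;  // DGLGetLastError() carries the details
  return DGLFuncFree(fn);
}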
int DGLCFuncSetReturn(DGLRetValueHandle ret, int DGLCFuncSetReturn(
DGLValue* value, DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret) {
int* type_code,
int num_ret) {
API_BEGIN(); API_BEGIN();
CHECK_EQ(num_ret, 1); CHECK_EQ(num_ret, 1);
DGLRetValue* rv = static_cast<DGLRetValue*>(ret); DGLRetValue* rv = static_cast<DGLRetValue*>(ret);
...@@ -291,16 +274,16 @@ int DGLCFuncSetReturn(DGLRetValueHandle ret, ...@@ -291,16 +274,16 @@ int DGLCFuncSetReturn(DGLRetValueHandle ret,
API_END(); API_END();
} }
int DGLFuncCreateFromCFunc(DGLPackedCFunc func, int DGLFuncCreateFromCFunc(
void* resource_handle, DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
DGLPackedCFuncFinalizer fin, DGLFunctionHandle* out) {
DGLFunctionHandle *out) {
API_BEGIN(); API_BEGIN();
if (fin == nullptr) { if (fin == nullptr) {
*out = new PackedFunc( *out =
[func, resource_handle](DGLArgs args, DGLRetValue* rv) { new PackedFunc([func, resource_handle](DGLArgs args, DGLRetValue* rv) {
int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*) int ret = func(
args.num_args, rv, resource_handle); (DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
if (ret != 0) { if (ret != 0) {
std::string err = "DGLCall CFunc Error:\n"; std::string err = "DGLCall CFunc Error:\n";
err += DGLGetLastError(); err += DGLGetLastError();
...@@ -311,16 +294,16 @@ int DGLFuncCreateFromCFunc(DGLPackedCFunc func, ...@@ -311,16 +294,16 @@ int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
// wrap it in a shared_ptr, with fin as deleter. // wrap it in a shared_ptr, with fin as deleter.
// so fin will be called when the lambda went out of scope. // so fin will be called when the lambda went out of scope.
std::shared_ptr<void> rpack(resource_handle, fin); std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc( *out = new PackedFunc([func, rpack](DGLArgs args, DGLRetValue* rv) {
[func, rpack](DGLArgs args, DGLRetValue* rv) { int ret = func(
int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*) (DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get()); args.num_args, rv, rpack.get());
if (ret != 0) { if (ret != 0) {
std::string err = "DGLCall CFunc Error:\n"; std::string err = "DGLCall CFunc Error:\n";
err += DGLGetLastError(); err += DGLGetLastError();
throw dmlc::Error(err); throw dmlc::Error(err);
} }
}); });
} }
API_END(); API_END();
} }
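The shared_ptr-with-deleter trick used above is worth spelling out on its own: the resource handle is wrapped so the user-supplied finalizer runs exactly once, when the last copy of the capturing lambda is destroyed. A standalone C++ sketch of the same pattern follows; the resource type and finalizer are made up for illustration and are not part of the DGL API.

#include <cstdio>
#include <functional>
#include <memory>

// Hypothetical C-style resource and finalizer, standing in for
// resource_handle / DGLPackedCFuncFinalizer above.
struct Resource { int id; };
void FinalizeResource(void* p) {
  std::printf("finalizing resource %d\n", static_cast<Resource*>(p)->id);
  delete static_cast<Resource*>(p);
}

std::function<void()> MakeCallback(Resource* raw) {
  // The deleter fires when the last lambda copy goes away, mirroring how
  // DGLFuncCreateFromCFunc ties `fin` to the lifetime of the PackedFunc.
  std::shared_ptr<void> rpack(raw, FinalizeResource);
  return [rpack] {
    std::printf("using resource %d\n", static_cast<Resource*>(rpack.get())->id);
  };
}

Copying the returned callback shares ownership of the resource; FinalizeResource runs once, only after the last copy has been destroyed.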
...@@ -370,10 +353,8 @@ int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) { ...@@ -370,10 +353,8 @@ int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
API_END(); API_END();
} }
int DGLStreamStreamSynchronize(int device_type, int DGLStreamStreamSynchronize(
int device_id, int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst) {
DGLStreamHandle src,
DGLStreamHandle dst) {
API_BEGIN(); API_BEGIN();
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
...@@ -392,36 +373,35 @@ int DGLCbArgToReturn(DGLValue* value, int code) { ...@@ -392,36 +373,35 @@ int DGLCbArgToReturn(DGLValue* value, int code) {
API_END(); API_END();
} }
int DGLLoadTensorAdapter(const char *path) { int DGLLoadTensorAdapter(const char* path) {
return TensorDispatcher::Global()->Load(path) ? 0 : -1; return TensorDispatcher::Global()->Load(path) ? 0 : -1;
} }
// set device api // set device api
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device) DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int()); ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
ctx.device_id = args[1]; ctx.device_id = args[1];
DeviceAPIManager::Get(ctx)->SetDevice(ctx); DeviceAPIManager::Get(ctx)->SetDevice(ctx);
}); });
// set device api // set device api
DGL_REGISTER_GLOBAL("_GetDeviceAttr") DGL_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue* ret) {
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int()); ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
ctx.device_id = args[1]; ctx.device_id = args[1];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int()); DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
if (kind == kExist) { if (kind == kExist) {
DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true); DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
if (api != nullptr) { if (api != nullptr) {
api->GetAttr(ctx, kind, ret); api->GetAttr(ctx, kind, ret);
} else {
*ret = 0;
}
} else { } else {
*ret = 0; DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
} }
} else { });
DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
}
});
...@@ -4,32 +4,28 @@ ...@@ -4,32 +4,28 @@
* \brief DGL runtime config * \brief DGL runtime config
*/ */
#include <dgl/runtime/registry.h>
#include <dgl/runtime/config.h> #include <dgl/runtime/config.h>
#include <dgl/runtime/registry.h>
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
void Config::EnableLibxsmm(bool b) { void Config::EnableLibxsmm(bool b) { libxsmm_ = b; }
libxsmm_ = b;
}
bool Config::IsLibxsmmAvailable() const { bool Config::IsLibxsmmAvailable() const { return libxsmm_; }
return libxsmm_;
}
DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigSetLibxsmm") DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigSetLibxsmm")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
bool use_libxsmm = args[0]; bool use_libxsmm = args[0];
dgl::runtime::Config::Global()->EnableLibxsmm(use_libxsmm); dgl::runtime::Config::Global()->EnableLibxsmm(use_libxsmm);
}); });
DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigGetLibxsmm") DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigGetLibxsmm")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = dgl::runtime::Config::Global()->IsLibxsmmAvailable(); *rv = dgl::runtime::Config::Global()->IsLibxsmmAvailable();
}); });
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
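On the C++ side the toggle registered above reduces to two calls on the global Config singleton. A minimal hedged sketch (the wrapper function is illustrative, not part of the sources):

#include <dgl/runtime/config.h>

// Flip the libxsmm flag and read it back via the Config API shown above.
void SetLibxsmmFlag(bool want_libxsmm) {
  dgl::runtime::Config::Global()->EnableLibxsmm(want_libxsmm);
  bool enabled = dgl::runtime::Config::Global()->IsLibxsmmAvailable();
  (void)enabled;  // reflects the flag just set
}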
...@@ -2,13 +2,15 @@ ...@@ -2,13 +2,15 @@
* Copyright (c) 2016-2022 by Contributors * Copyright (c) 2016-2022 by Contributors
* \file cpu_device_api.cc * \file cpu_device_api.cc
*/ */
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h> #include <dgl/runtime/tensordispatch.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include "workspace_pool.h" #include "workspace_pool.h"
namespace dgl { namespace dgl {
...@@ -21,13 +23,11 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -21,13 +23,11 @@ class CPUDeviceAPI final : public DeviceAPI {
*rv = 1; *rv = 1;
} }
} }
void* AllocDataSpace(DGLContext ctx, void* AllocDataSpace(
size_t nbytes, DGLContext ctx, size_t nbytes, size_t alignment,
size_t alignment, DGLDataType type_hint) final {
DGLDataType type_hint) final {
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable()) return td->CPUAllocWorkspace(nbytes);
return td->CPUAllocWorkspace(nbytes);
void* ptr; void* ptr;
#if _MSC_VER || defined(__MINGW32__) #if _MSC_VER || defined(__MINGW32__)
...@@ -45,8 +45,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -45,8 +45,7 @@ class CPUDeviceAPI final : public DeviceAPI {
void FreeDataSpace(DGLContext ctx, void* ptr) final { void FreeDataSpace(DGLContext ctx, void* ptr) final {
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable()) return td->CPUFreeWorkspace(ptr);
return td->CPUFreeWorkspace(ptr);
#if _MSC_VER || defined(__MINGW32__) #if _MSC_VER || defined(__MINGW32__)
_aligned_free(ptr); _aligned_free(ptr);
...@@ -55,25 +54,21 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -55,25 +54,21 @@ class CPUDeviceAPI final : public DeviceAPI {
#endif #endif
} }
void CopyDataFromTo(const void* from, void CopyDataFromTo(
size_t from_offset, const void* from, size_t from_offset, void* to, size_t to_offset,
void* to, size_t size, DGLContext ctx_from, DGLContext ctx_to,
size_t to_offset, DGLDataType type_hint) final {
size_t size, memcpy(
DGLContext ctx_from, static_cast<char*>(to) + to_offset,
DGLContext ctx_to, static_cast<const char*>(from) + from_offset, size);
DGLDataType type_hint) final {
memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
} }
DGLStreamHandle CreateStream(DGLContext) final { return nullptr; } DGLStreamHandle CreateStream(DGLContext) final { return nullptr; }
void StreamSync(DGLContext ctx, DGLStreamHandle stream) final { void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {}
}
void* AllocWorkspace(DGLContext ctx, size_t size, DGLDataType type_hint) final; void* AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) final;
void FreeWorkspace(DGLContext ctx, void* data) final; void FreeWorkspace(DGLContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() { static const std::shared_ptr<CPUDeviceAPI>& Global() {
...@@ -84,32 +79,29 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -84,32 +79,29 @@ class CPUDeviceAPI final : public DeviceAPI {
}; };
struct CPUWorkspacePool : public WorkspacePool { struct CPUWorkspacePool : public WorkspacePool {
CPUWorkspacePool() : CPUWorkspacePool() : WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
}; };
void* CPUDeviceAPI::AllocWorkspace(DGLContext ctx, void* CPUDeviceAPI::AllocWorkspace(
size_t size, DGLContext ctx, size_t size, DGLDataType type_hint) {
DGLDataType type_hint) {
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable()) return td->CPUAllocWorkspace(size);
return td->CPUAllocWorkspace(size);
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(ctx, size); return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(
ctx, size);
} }
void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) { void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable()) return td->CPUFreeWorkspace(data);
return td->CPUFreeWorkspace(data);
dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data); dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
} }
DGL_REGISTER_GLOBAL("device_api.cpu") DGL_REGISTER_GLOBAL("device_api.cpu")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CPUDeviceAPI::Global().get(); DeviceAPI* ptr = CPUDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
}); });
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
...@@ -7,11 +7,13 @@ ...@@ -7,11 +7,13 @@
#define DGL_RUNTIME_CUDA_CUDA_COMMON_H_ #define DGL_RUNTIME_CUDA_CUDA_COMMON_H_
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cusparse.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <curand.h> #include <curand.h>
#include <cusparse.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <string> #include <string>
#include "../workspace_pool.h" #include "../workspace_pool.h"
namespace dgl { namespace dgl {
...@@ -19,94 +21,89 @@ namespace runtime { ...@@ -19,94 +21,89 @@ namespace runtime {
template <typename T> template <typename T>
inline bool is_zero(T size) { inline bool is_zero(T size) {
return size == 0; return size == 0;
} }
template <> template <>
inline bool is_zero<dim3>(dim3 size) { inline bool is_zero<dim3>(dim3 size) {
return size.x == 0 || size.y == 0 || size.z == 0; return size.x == 0 || size.y == 0 || size.z == 0;
} }
#define CUDA_DRIVER_CALL(x) \ #define CUDA_DRIVER_CALL(x) \
{ \ { \
CUresult result = x; \ CUresult result = x; \
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \ if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \
const char *msg; \ const char* msg; \
cuGetErrorName(result, &msg); \ cuGetErrorName(result, &msg); \
LOG(FATAL) \ LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg; \
<< "CUDAError: " #x " failed with error: " << msg; \
} \ } \
} }
#define CUDA_CALL(func) \ #define CUDA_CALL(func) \
{ \ { \
cudaError_t e = (func); \ cudaError_t e = (func); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA: " << cudaGetErrorString(e); \ << "CUDA: " << cudaGetErrorString(e); \
} }
#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...) \ #define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...) \
{ \ { \
if (!dgl::runtime::is_zero((nblks)) && \ if (!dgl::runtime::is_zero((nblks)) && !dgl::runtime::is_zero((nthrs))) { \
!dgl::runtime::is_zero((nthrs))) { \ (kernel)<<<(nblks), (nthrs), (shmem), (stream)>>>(__VA_ARGS__); \
(kernel) <<< (nblks), (nthrs), (shmem), (stream) >>> \ cudaError_t e = cudaGetLastError(); \
(__VA_ARGS__); \ CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
cudaError_t e = cudaGetLastError(); \ << "CUDA kernel launch error: " << cudaGetErrorString(e); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ } \
<< "CUDA kernel launch error: " \
<< cudaGetErrorString(e); \
} \
} }
#define CUSPARSE_CALL(func) \ #define CUSPARSE_CALL(func) \
{ \ { \
cusparseStatus_t e = (func); \ cusparseStatus_t e = (func); \
CHECK(e == CUSPARSE_STATUS_SUCCESS) \ CHECK(e == CUSPARSE_STATUS_SUCCESS) << "CUSPARSE ERROR: " << e; \
<< "CUSPARSE ERROR: " << e; \
} }
#define CUBLAS_CALL(func) \ #define CUBLAS_CALL(func) \
{ \ { \
cublasStatus_t e = (func); \ cublasStatus_t e = (func); \
CHECK(e == CUBLAS_STATUS_SUCCESS) << "CUBLAS ERROR: " << e; \ CHECK(e == CUBLAS_STATUS_SUCCESS) << "CUBLAS ERROR: " << e; \
} }
#define CURAND_CALL(func) \ #define CURAND_CALL(func) \
{ \ { \
curandStatus_t e = (func); \ curandStatus_t e = (func); \
CHECK(e == CURAND_STATUS_SUCCESS) \ CHECK(e == CURAND_STATUS_SUCCESS) \
<< "CURAND Error: " << dgl::runtime::curandGetErrorString(e) \ << "CURAND Error: " << dgl::runtime::curandGetErrorString(e) << " at " \
<< " at " << __FILE__ << ":" << __LINE__; \ << __FILE__ << ":" << __LINE__; \
} }
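A short sketch of how the checking macros above are typically used from host code. The kernel, launch geometry, and buffer are placeholders, and the snippet assumes it is compiled as CUDA inside the DGL tree so that cuda_common.h is available.

#include "cuda_common.h"  // provides CUDA_CALL / CUDA_KERNEL_CALL (this header)

// Hypothetical kernel, used only to illustrate the launch macro.
__global__ void fill_ones(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 1.0f;
}

void FillOnesOnDevice(float* d_out, int n, cudaStream_t stream) {
  const int nthrs = 256;
  const int nblks = (n + nthrs - 1) / nthrs;
  // CUDA_KERNEL_CALL skips the launch when either dimension is zero and
  // checks cudaGetLastError() right after the launch.
  CUDA_KERNEL_CALL(fill_ones, nblks, nthrs, 0, stream, d_out, n);
  // Plain runtime calls go through CUDA_CALL, which also tolerates
  // cudaErrorCudartUnloading during process shutdown.
  CUDA_CALL(cudaStreamSynchronize(stream));
}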
inline const char* curandGetErrorString(curandStatus_t error) { inline const char* curandGetErrorString(curandStatus_t error) {
switch (error) { switch (error) {
case CURAND_STATUS_SUCCESS: case CURAND_STATUS_SUCCESS:
return "CURAND_STATUS_SUCCESS"; return "CURAND_STATUS_SUCCESS";
case CURAND_STATUS_VERSION_MISMATCH: case CURAND_STATUS_VERSION_MISMATCH:
return "CURAND_STATUS_VERSION_MISMATCH"; return "CURAND_STATUS_VERSION_MISMATCH";
case CURAND_STATUS_NOT_INITIALIZED: case CURAND_STATUS_NOT_INITIALIZED:
return "CURAND_STATUS_NOT_INITIALIZED"; return "CURAND_STATUS_NOT_INITIALIZED";
case CURAND_STATUS_ALLOCATION_FAILED: case CURAND_STATUS_ALLOCATION_FAILED:
return "CURAND_STATUS_ALLOCATION_FAILED"; return "CURAND_STATUS_ALLOCATION_FAILED";
case CURAND_STATUS_TYPE_ERROR: case CURAND_STATUS_TYPE_ERROR:
return "CURAND_STATUS_TYPE_ERROR"; return "CURAND_STATUS_TYPE_ERROR";
case CURAND_STATUS_OUT_OF_RANGE: case CURAND_STATUS_OUT_OF_RANGE:
return "CURAND_STATUS_OUT_OF_RANGE"; return "CURAND_STATUS_OUT_OF_RANGE";
case CURAND_STATUS_LENGTH_NOT_MULTIPLE: case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
case CURAND_STATUS_LAUNCH_FAILURE: case CURAND_STATUS_LAUNCH_FAILURE:
return "CURAND_STATUS_LAUNCH_FAILURE"; return "CURAND_STATUS_LAUNCH_FAILURE";
case CURAND_STATUS_PREEXISTING_FAILURE: case CURAND_STATUS_PREEXISTING_FAILURE:
return "CURAND_STATUS_PREEXISTING_FAILURE"; return "CURAND_STATUS_PREEXISTING_FAILURE";
case CURAND_STATUS_INITIALIZATION_FAILED: case CURAND_STATUS_INITIALIZATION_FAILED:
return "CURAND_STATUS_INITIALIZATION_FAILED"; return "CURAND_STATUS_INITIALIZATION_FAILED";
case CURAND_STATUS_ARCH_MISMATCH: case CURAND_STATUS_ARCH_MISMATCH:
return "CURAND_STATUS_ARCH_MISMATCH"; return "CURAND_STATUS_ARCH_MISMATCH";
case CURAND_STATUS_INTERNAL_ERROR: case CURAND_STATUS_INTERNAL_ERROR:
return "CURAND_STATUS_INTERNAL_ERROR"; return "CURAND_STATUS_INTERNAL_ERROR";
} }
// To suppress compiler warning. // To suppress compiler warning.
return "Unrecognized curand error string"; return "Unrecognized curand error string";
......
...@@ -3,11 +3,12 @@ ...@@ -3,11 +3,12 @@
* \file cuda_device_api.cc * \file cuda_device_api.cc
* \brief GPU specific API * \brief GPU specific API
*/ */
#include <cuda_runtime.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h> #include <dgl/runtime/tensordispatch.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <cuda_runtime.h>
#include "cuda_common.h" #include "cuda_common.h"
namespace dgl { namespace dgl {
...@@ -28,9 +29,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -28,9 +29,7 @@ class CUDADeviceAPI final : public DeviceAPI {
is_available_ = count > 0; is_available_ = count > 0;
} }
bool IsAvailable() final { bool IsAvailable() final { return is_available_; }
return is_available_;
}
void SetDevice(DGLContext ctx) final { void SetDevice(DGLContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.device_id)); CUDA_CALL(cudaSetDevice(ctx.device_id));
...@@ -39,10 +38,10 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -39,10 +38,10 @@ class CUDADeviceAPI final : public DeviceAPI {
int value = 0; int value = 0;
switch (kind) { switch (kind) {
case kExist: case kExist:
value = ( value =
cudaDeviceGetAttribute( (cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) ==
== cudaSuccess); cudaSuccess);
break; break;
case kMaxThreadsPerBlock: { case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute( CUDA_CALL(cudaDeviceGetAttribute(
...@@ -50,8 +49,8 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -50,8 +49,8 @@ class CUDADeviceAPI final : public DeviceAPI {
break; break;
} }
case kWarpSize: { case kWarpSize: {
CUDA_CALL(cudaDeviceGetAttribute( CUDA_CALL(
&value, cudaDevAttrWarpSize, ctx.device_id)); cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, ctx.device_id));
break; break;
} }
case kMaxSharedMemoryPerBlock: { case kMaxSharedMemoryPerBlock: {
...@@ -96,26 +95,24 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -96,26 +95,24 @@ class CUDADeviceAPI final : public DeviceAPI {
&dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id)); &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));
std::stringstream ss; // use json string to return multiple int values; std::stringstream ss; // use json string to return multiple int values;
ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]"; ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str(); *rv = ss.str();
return; return;
} }
} }
*rv = value; *rv = value;
} }
void* AllocDataSpace(DGLContext ctx, void* AllocDataSpace(
size_t nbytes, DGLContext ctx, size_t nbytes, size_t alignment,
size_t alignment, DGLDataType type_hint) final {
DGLDataType type_hint) final {
SetDevice(ctx); SetDevice(ctx);
// Redirect to PyTorch's allocator when available. // Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable())
return td->CUDAAllocWorkspace(nbytes, getCurrentCUDAStream()); return td->CUDAAllocWorkspace(nbytes, getCurrentCUDAStream());
CHECK_EQ(256 % alignment, 0U) CHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes";
<< "CUDA space is aligned at 256 bytes"; void* ret;
void *ret;
CUDA_CALL(cudaMalloc(&ret, nbytes)); CUDA_CALL(cudaMalloc(&ret, nbytes));
return ret; return ret;
} }
...@@ -123,21 +120,15 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -123,21 +120,15 @@ class CUDADeviceAPI final : public DeviceAPI {
void FreeDataSpace(DGLContext ctx, void* ptr) final { void FreeDataSpace(DGLContext ctx, void* ptr) final {
SetDevice(ctx); SetDevice(ctx);
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable()) return td->CUDAFreeWorkspace(ptr);
return td->CUDAFreeWorkspace(ptr);
CUDA_CALL(cudaFree(ptr)); CUDA_CALL(cudaFree(ptr));
} }
void CopyDataFromTo(const void* from, void CopyDataFromTo(
size_t from_offset, const void* from, size_t from_offset, void* to, size_t to_offset,
void* to, size_t size, DGLContext ctx_from, DGLContext ctx_to,
size_t to_offset, DGLDataType type_hint, DGLStreamHandle stream) {
size_t size,
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint,
DGLStreamHandle stream) {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream); cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
from = static_cast<const char*>(from) + from_offset; from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset; to = static_cast<char*>(to) + to_offset;
...@@ -146,14 +137,15 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -146,14 +137,15 @@ class CUDADeviceAPI final : public DeviceAPI {
if (ctx_from.device_id == ctx_to.device_id) { if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream); GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
} else { } else {
CUDA_CALL(cudaMemcpyPeerAsync(to, ctx_to.device_id, CUDA_CALL(cudaMemcpyPeerAsync(
from, ctx_from.device_id, to, ctx_to.device_id, from, ctx_from.device_id, size, cu_stream));
size, cu_stream));
} }
} else if (ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCPU) { } else if (
ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id)); CUDA_CALL(cudaSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream); GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
} else if (ctx_from.device_type == kDGLCPU && ctx_to.device_type == kDGLCUDA) { } else if (
ctx_from.device_type == kDGLCPU && ctx_to.device_type == kDGLCUDA) {
CUDA_CALL(cudaSetDevice(ctx_to.device_id)); CUDA_CALL(cudaSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream); GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
} else { } else {
...@@ -161,16 +153,14 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -161,16 +153,14 @@ class CUDADeviceAPI final : public DeviceAPI {
} }
} }
void CopyDataFromTo(const void* from, void CopyDataFromTo(
size_t from_offset, const void* from, size_t from_offset, void* to, size_t to_offset,
void* to, size_t size, DGLContext ctx_from, DGLContext ctx_to,
size_t to_offset, DGLDataType type_hint) final {
size_t size,
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint) final {
auto stream = GetStream(); auto stream = GetStream();
CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint, stream); CopyDataFromTo(
from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint,
stream);
} }
DGLStreamHandle CreateStream(DGLContext ctx) { DGLStreamHandle CreateStream(DGLContext ctx) {
...@@ -187,7 +177,8 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -187,7 +177,8 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDA_CALL(cudaStreamDestroy(cu_stream)); CUDA_CALL(cudaStreamDestroy(cu_stream));
} }
void SyncStreamFromTo(DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) { void SyncStreamFromTo(
DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
CUDA_CALL(cudaSetDevice(ctx.device_id)); CUDA_CALL(cudaSetDevice(ctx.device_id));
cudaStream_t src_stream = static_cast<cudaStream_t>(event_src); cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst); cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
...@@ -222,54 +213,54 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -222,54 +213,54 @@ class CUDADeviceAPI final : public DeviceAPI {
*/ */
void PinData(void* ptr, size_t nbytes) { void PinData(void* ptr, size_t nbytes) {
// prevent users from pinning empty tensors or graphs // prevent users from pinning empty tensors or graphs
if (ptr == nullptr || nbytes == 0) if (ptr == nullptr || nbytes == 0) return;
return;
CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault)); CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
} }
void UnpinData(void* ptr) { void UnpinData(void* ptr) {
if (ptr == nullptr) if (ptr == nullptr) return;
return;
CUDA_CALL(cudaHostUnregister(ptr)); CUDA_CALL(cudaHostUnregister(ptr));
} }
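For context, pinning here is simply cudaHostRegister/cudaHostUnregister on an existing host buffer. A hedged standalone sketch, assuming cuda_common.h for CUDA_CALL; the buffer size is arbitrary.

#include <cuda_runtime.h>
#include <vector>
#include "cuda_common.h"  // for CUDA_CALL

// Page-lock an ordinary host buffer so asynchronous copies to and from it
// do not fall back to a staged, synchronous path.
void PinnedBufferExample(size_t nbytes) {
  if (nbytes == 0) return;  // same empty-buffer guard as PinData above
  std::vector<char> host(nbytes);
  CUDA_CALL(cudaHostRegister(host.data(), nbytes, cudaHostRegisterDefault));
  // ... enqueue cudaMemcpyAsync calls that touch host.data() ...
  CUDA_CALL(cudaHostUnregister(host.data()));
}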
bool IsPinned(const void* ptr) override { bool IsPinned(const void* ptr) override {
// can't be a pinned tensor if CUDA context is unavailable. // can't be a pinned tensor if CUDA context is unavailable.
if (!is_available_) if (!is_available_) return false;
return false;
cudaPointerAttributes attr; cudaPointerAttributes attr;
cudaError_t status = cudaPointerGetAttributes(&attr, ptr); cudaError_t status = cudaPointerGetAttributes(&attr, ptr);
bool result = false; bool result = false;
switch (status) { switch (status) {
case cudaErrorInvalidValue: case cudaErrorInvalidValue:
// might be a normal CPU tensor in CUDA 10.2 or earlier // might be a normal CPU tensor in CUDA 10.2 or earlier
cudaGetLastError(); // clear error cudaGetLastError(); // clear error
break; break;
case cudaSuccess: case cudaSuccess:
result = (attr.type == cudaMemoryTypeHost); result = (attr.type == cudaMemoryTypeHost);
break; break;
case cudaErrorInitializationError: case cudaErrorInitializationError:
case cudaErrorNoDevice: case cudaErrorNoDevice:
case cudaErrorInsufficientDriver: case cudaErrorInsufficientDriver:
case cudaErrorInvalidDevice: case cudaErrorInvalidDevice:
// We don't want to fail in these particular cases since this function can be called // We don't want to fail in these particular cases since this function
// when users only want to run on CPU even if CUDA API is enabled, or in a forked // can be called when users only want to run on CPU even if CUDA API is
// subprocess where CUDA context cannot be initialized. So we just mark the CUDA // enabled, or in a forked subprocess where CUDA context cannot be
// context to unavailable and return. // initialized. So we just mark the CUDA context to unavailable and
is_available_ = false; // return.
cudaGetLastError(); // clear error is_available_ = false;
break; cudaGetLastError(); // clear error
default: break;
LOG(FATAL) << "error while determining memory status: " << cudaGetErrorString(status); default:
break; LOG(FATAL) << "error while determining memory status: "
<< cudaGetErrorString(status);
break;
} }
return result; return result;
} }
void* AllocWorkspace(DGLContext ctx, size_t size, DGLDataType type_hint) final { void* AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) final {
SetDevice(ctx); SetDevice(ctx);
// Redirect to PyTorch's allocator when available. // Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
...@@ -282,8 +273,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -282,8 +273,7 @@ class CUDADeviceAPI final : public DeviceAPI {
void FreeWorkspace(DGLContext ctx, void* data) final { void FreeWorkspace(DGLContext ctx, void* data) final {
SetDevice(ctx); SetDevice(ctx);
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable()) return td->CUDAFreeWorkspace(data);
return td->CUDAFreeWorkspace(data);
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
} }
...@@ -295,14 +285,13 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -295,14 +285,13 @@ class CUDADeviceAPI final : public DeviceAPI {
} }
private: private:
static void GPUCopy(const void* from, static void GPUCopy(
void* to, const void* from, void* to, size_t size, cudaMemcpyKind kind,
size_t size, cudaStream_t stream) {
cudaMemcpyKind kind,
cudaStream_t stream) {
CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream)); CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
if (stream == 0 && kind == cudaMemcpyDeviceToHost) { if (stream == 0 && kind == cudaMemcpyDeviceToHost) {
// only wait for the copy, when it's on the default stream, and it's to host memory // only wait for the copy, when it's on the default stream, and it's to
// host memory
CUDA_CALL(cudaStreamSynchronize(stream)); CUDA_CALL(cudaStreamSynchronize(stream));
} }
} }
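The default-stream special case in GPUCopy matters because callers may read the host buffer as soon as CopyDataFromTo returns. A hedged standalone version of the same pattern, again assuming cuda_common.h for CUDA_CALL:

#include <cuda_runtime.h>
#include "cuda_common.h"  // for CUDA_CALL

// Enqueue the copy asynchronously, then block only in the one case where
// the caller expects the host data to be ready immediately.
void CopyDeviceToHost(void* dst_host, const void* src_dev, size_t size,
                      cudaStream_t stream) {
  CUDA_CALL(cudaMemcpyAsync(dst_host, src_dev, size,
                            cudaMemcpyDeviceToHost, stream));
  if (stream == 0) {
    // Legacy default stream + device-to-host: synchronize so the data is
    // visible as soon as this function returns.
    CUDA_CALL(cudaStreamSynchronize(stream));
  }
}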
...@@ -312,9 +301,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -312,9 +301,7 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore; typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
CUDAThreadEntry::CUDAThreadEntry() CUDAThreadEntry::CUDAThreadEntry() : pool(kDGLCUDA, CUDADeviceAPI::Global()) {}
: pool(kDGLCUDA, CUDADeviceAPI::Global()) {
}
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
return CUDAThreadStore::Get(); return CUDAThreadStore::Get();
...@@ -329,10 +316,10 @@ cudaStream_t getCurrentCUDAStream() { ...@@ -329,10 +316,10 @@ cudaStream_t getCurrentCUDAStream() {
} }
DGL_REGISTER_GLOBAL("device_api.cuda") DGL_REGISTER_GLOBAL("device_api.cuda")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get(); DeviceAPI* ptr = CUDADeviceAPI::Global().get();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
}); });
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
...@@ -6,10 +6,10 @@ ...@@ -6,10 +6,10 @@
#include <cassert> #include <cassert>
#include "cuda_common.h"
#include "cuda_hashtable.cuh"
#include "../../array/cuda/atomic.cuh" #include "../../array/cuda/atomic.cuh"
#include "../../array/cuda/dgl_cub.cuh" #include "../../array/cuda/dgl_cub.cuh"
#include "cuda_common.h"
#include "cuda_hashtable.cuh"
using namespace dgl::aten::cuda; using namespace dgl::aten::cuda;
...@@ -23,64 +23,62 @@ constexpr static const int BLOCK_SIZE = 256; ...@@ -23,64 +23,62 @@ constexpr static const int BLOCK_SIZE = 256;
constexpr static const size_t TILE_SIZE = 1024; constexpr static const size_t TILE_SIZE = 1024;
/** /**
* @brief This is the mutable version of the DeviceOrderedHashTable, for use in * @brief This is the mutable version of the DeviceOrderedHashTable, for use in
* inserting elements into the hashtable. * inserting elements into the hashtable.
* *
* @tparam IdType The type of ID to store in the hashtable. * @tparam IdType The type of ID to store in the hashtable.
*/ */
template<typename IdType> template <typename IdType>
class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> { class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
public: public:
typedef typename DeviceOrderedHashTable<IdType>::Mapping* Iterator; typedef typename DeviceOrderedHashTable<IdType>::Mapping* Iterator;
static constexpr IdType kEmptyKey = DeviceOrderedHashTable<IdType>::kEmptyKey; static constexpr IdType kEmptyKey = DeviceOrderedHashTable<IdType>::kEmptyKey;
/** /**
* @brief Create a new mutable hashtable for use on the device. * @brief Create a new mutable hashtable for use on the device.
* *
* @param hostTable The original hash table on the host. * @param hostTable The original hash table on the host.
*/ */
explicit MutableDeviceOrderedHashTable( explicit MutableDeviceOrderedHashTable(
OrderedHashTable<IdType>* const hostTable) : OrderedHashTable<IdType>* const hostTable)
DeviceOrderedHashTable<IdType>(hostTable->DeviceHandle()) { : DeviceOrderedHashTable<IdType>(hostTable->DeviceHandle()) {}
}
/** /**
* @brief Find the mutable mapping of a given key within the hash table. * @brief Find the mutable mapping of a given key within the hash table.
* *
* WARNING: The key must exist within the hashtable. Searching for a key not * WARNING: The key must exist within the hashtable. Searching for a key not
* in the hashtable is undefined behavior. * in the hashtable is undefined behavior.
* *
* @param id The key to search for. * @param id The key to search for.
* *
* @return The mapping. * @return The mapping.
*/ */
inline __device__ Iterator Search( inline __device__ Iterator Search(const IdType id) {
const IdType id) {
const IdType pos = SearchForPosition(id); const IdType pos = SearchForPosition(id);
return GetMutable(pos); return GetMutable(pos);
} }
/** /**
* \brief Attempt to insert into the hash table at a specific location. * \brief Attempt to insert into the hash table at a specific location.
* *
* \param pos The position to insert at. * \param pos The position to insert at.
* \param id The ID to insert into the hash table. * \param id The ID to insert into the hash table.
* \param index The original index of the item being inserted. * \param index The original index of the item being inserted.
* *
* \return True, if the insertion was successful. * \return True, if the insertion was successful.
*/ */
inline __device__ bool AttemptInsertAt( inline __device__ bool AttemptInsertAt(
const size_t pos, const size_t pos, const IdType id, const size_t index) {
const IdType id,
const size_t index) {
const IdType key = AtomicCAS(&GetMutable(pos)->key, kEmptyKey, id); const IdType key = AtomicCAS(&GetMutable(pos)->key, kEmptyKey, id);
if (key == kEmptyKey || key == id) { if (key == kEmptyKey || key == id) {
// we either set the key ourselves, or found a matching key, so then place the // we either set the key ourselves, or found a matching key, so then place the
// minimum index in position. Match the type of atomicMin, so ignore // minimum index in position. Match the type of atomicMin, so ignore
// linting // linting
atomicMin(reinterpret_cast<unsigned long long*>(&GetMutable(pos)->index), // NOLINT atomicMin(
static_cast<unsigned long long>(index)); // NOLINT reinterpret_cast<unsigned long long*>( // NOLINT
&GetMutable(pos)->index),
static_cast<unsigned long long>(index)); // NOLINT
return true; return true;
} else { } else {
// we need to search elsewhere // we need to search elsewhere
...@@ -89,23 +87,21 @@ class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> { ...@@ -89,23 +87,21 @@ class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
} }
/** /**
* @brief Insert key-index pair into the hashtable. * @brief Insert key-index pair into the hashtable.
* *
* @param id The ID to insert. * @param id The ID to insert.
* @param index The index at which the ID occurred. * @param index The index at which the ID occurred.
* *
* @return An iterator to inserted mapping. * @return An iterator to inserted mapping.
*/ */
inline __device__ Iterator Insert( inline __device__ Iterator Insert(const IdType id, const size_t index) {
const IdType id,
const size_t index) {
size_t pos = Hash(id); size_t pos = Hash(id);
// linearly scan for an empty slot or matching entry // linearly scan for an empty slot or matching entry
IdType delta = 1; IdType delta = 1;
while (!AttemptInsertAt(pos, id, index)) { while (!AttemptInsertAt(pos, id, index)) {
pos = Hash(pos+delta); pos = Hash(pos + delta);
delta +=1; delta += 1;
} }
return GetMutable(pos); return GetMutable(pos);
...@@ -113,88 +109,84 @@ class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> { ...@@ -113,88 +109,84 @@ class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
private: private:
/** /**
* @brief Get a mutable iterator to the given bucket in the hashtable. * @brief Get a mutable iterator to the given bucket in the hashtable.
* *
* @param pos The given bucket. * @param pos The given bucket.
* *
* @return The iterator. * @return The iterator.
*/ */
inline __device__ Iterator GetMutable(const size_t pos) { inline __device__ Iterator GetMutable(const size_t pos) {
assert(pos < this->size_); assert(pos < this->size_);
// The parent class Device is read-only, but we ensure this can only be // The parent class Device is read-only, but we ensure this can only be
// constructed from a mutable version of OrderedHashTable, making this // constructed from a mutable version of OrderedHashTable, making this
// a safe cast to perform. // a safe cast to perform.
return const_cast<Iterator>(this->table_+pos); return const_cast<Iterator>(this->table_ + pos);
} }
}; };
/** /**
* @brief Calculate the number of buckets in the hashtable. To guarantee we can * @brief Calculate the number of buckets in the hashtable. To guarantee we can
* fill the hashtable in the worst case, we must use a number of buckets which * fill the hashtable in the worst case, we must use a number of buckets which
* is a power of two. * is a power of two.
* https://en.wikipedia.org/wiki/Quadratic_probing#Limitations * https://en.wikipedia.org/wiki/Quadratic_probing#Limitations
* *
* @param num The number of items to insert (should be an upper bound on the * @param num The number of items to insert (should be an upper bound on the
* number of unique keys). * number of unique keys).
* @param scale The power of two by which the number of buckets should exceed * @param scale The power of two by which the number of buckets should exceed
* the number of unique keys. * the number of unique keys.
* *
* @return The number of buckets the table should contain. * @return The number of buckets the table should contain.
*/ */
size_t TableSize( size_t TableSize(const size_t num, const int scale) {
const size_t num,
const int scale) {
const size_t next_pow2 = 1 << static_cast<size_t>(1 + std::log2(num >> 1)); const size_t next_pow2 = 1 << static_cast<size_t>(1 + std::log2(num >> 1));
return next_pow2 << scale; return next_pow2 << scale;
} }
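A quick host-side check of the sizing formula above; the formula is copied so the sketch is self-contained, and the inputs are arbitrary.

#include <cmath>
#include <cstdio>

// Same computation as TableSize above, duplicated for a standalone check.
size_t TableSizeSketch(size_t num, int scale) {
  const size_t next_pow2 = 1 << static_cast<size_t>(1 + std::log2(num >> 1));
  return next_pow2 << scale;
}

int main() {
  // 1000 items, scale 1: next_pow2 = 512, so 1024 buckets (>= 1000).
  std::printf("%zu\n", TableSizeSketch(1000, 1));  // prints 1024
  // 4096 items, scale 2: next_pow2 = 4096, so 16384 buckets.
  std::printf("%zu\n", TableSizeSketch(4096, 2));  // prints 16384
  return 0;
}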
/** /**
* @brief This structure is used with cub's block-level prefixscan in order to * @brief This structure is used with cub's block-level prefixscan in order to
* keep a running sum as items are iteratively processed. * keep a running sum as items are iteratively processed.
* *
* @tparam IdType The type to perform the prefixsum on. * @tparam IdType The type to perform the prefixsum on.
*/ */
template<typename IdType> template <typename IdType>
struct BlockPrefixCallbackOp { struct BlockPrefixCallbackOp {
IdType running_total_; IdType running_total_;
__device__ BlockPrefixCallbackOp( __device__ BlockPrefixCallbackOp(const IdType running_total)
const IdType running_total) : : running_total_(running_total) {}
running_total_(running_total) {
}
__device__ IdType operator()(const IdType block_aggregate) { __device__ IdType operator()(const IdType block_aggregate) {
const IdType old_prefix = running_total_; const IdType old_prefix = running_total_;
running_total_ += block_aggregate; running_total_ += block_aggregate;
return old_prefix; return old_prefix;
} }
}; };
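The functor's contract is easiest to see outside the kernel: the block-level scan calls operator() once per tile with that tile's aggregate, and the return value is the running total accumulated before the tile. A hedged host-only analogue, dropping the __device__ qualifier:

#include <cassert>

// Host-only analogue of BlockPrefixCallbackOp<IdType> above.
struct RunningTotal {
  int total;
  explicit RunningTotal(int start) : total(start) {}
  int operator()(int block_aggregate) {
    const int old_prefix = total;  // value handed back to the scan
    total += block_aggregate;      // carried into the next tile
    return old_prefix;
  }
};

int main() {
  RunningTotal op(0);
  assert(op(4) == 0);  // first tile sees prefix 0
  assert(op(2) == 4);  // second tile sees the first tile's total
  assert(op(5) == 6);  // and so on; op.total is now 11
  return 0;
}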
} // namespace } // namespace
/** /**
* \brief This generates a hash map where the keys are the global item numbers, * \brief This generates a hash map where the keys are the global item numbers,
* and the values are indexes, and inputs may have duplicates. * and the values are indexes, and inputs may have duplicates.
* *
* \tparam IdType The type of id. * \tparam IdType The type of id.
* \tparam BLOCK_SIZE The size of the thread block. * \tparam BLOCK_SIZE The size of the thread block.
* \tparam TILE_SIZE The number of entries each thread block will process. * \tparam TILE_SIZE The number of entries each thread block will process.
* \param items The items to insert. * \param items The items to insert.
* \param num_items The number of items to insert. * \param num_items The number of items to insert.
* \param table The hash table. * \param table The hash table.
*/ */
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE> template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void generate_hashmap_duplicates( __global__ void generate_hashmap_duplicates(
const IdType * const items, const IdType* const items, const int64_t num_items,
const int64_t num_items,
MutableDeviceOrderedHashTable<IdType> table) { MutableDeviceOrderedHashTable<IdType> table) {
assert(BLOCK_SIZE == blockDim.x); assert(BLOCK_SIZE == blockDim.x);
const size_t block_start = TILE_SIZE*blockIdx.x; const size_t block_start = TILE_SIZE * blockIdx.x;
const size_t block_end = TILE_SIZE*(blockIdx.x+1); const size_t block_end = TILE_SIZE * (blockIdx.x + 1);
#pragma unroll #pragma unroll
for (size_t index = threadIdx.x + block_start; index < block_end; index += BLOCK_SIZE) { for (size_t index = threadIdx.x + block_start; index < block_end;
index += BLOCK_SIZE) {
if (index < num_items) { if (index < num_items) {
table.Insert(items[index], index); table.Insert(items[index], index);
} }
...@@ -202,30 +194,30 @@ __global__ void generate_hashmap_duplicates( ...@@ -202,30 +194,30 @@ __global__ void generate_hashmap_duplicates(
} }
/** /**
* \brief This generates a hash map where the keys are the global item numbers, * \brief This generates a hash map where the keys are the global item numbers,
* and the values are indexes, and all inputs are unique. * and the values are indexes, and all inputs are unique.
* *
* \tparam IdType The type of id. * \tparam IdType The type of id.
* \tparam BLOCK_SIZE The size of the thread block. * \tparam BLOCK_SIZE The size of the thread block.
* \tparam TILE_SIZE The number of entries each thread block will process. * \tparam TILE_SIZE The number of entries each thread block will process.
* \param items The unique items to insert. * \param items The unique items to insert.
* \param num_items The number of items to insert. * \param num_items The number of items to insert.
* \param table The hash table. * \param table The hash table.
*/ */
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE> template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void generate_hashmap_unique( __global__ void generate_hashmap_unique(
const IdType * const items, const IdType* const items, const int64_t num_items,
const int64_t num_items,
MutableDeviceOrderedHashTable<IdType> table) { MutableDeviceOrderedHashTable<IdType> table) {
assert(BLOCK_SIZE == blockDim.x); assert(BLOCK_SIZE == blockDim.x);
using Iterator = typename MutableDeviceOrderedHashTable<IdType>::Iterator; using Iterator = typename MutableDeviceOrderedHashTable<IdType>::Iterator;
const size_t block_start = TILE_SIZE*blockIdx.x; const size_t block_start = TILE_SIZE * blockIdx.x;
const size_t block_end = TILE_SIZE*(blockIdx.x+1); const size_t block_end = TILE_SIZE * (blockIdx.x + 1);
#pragma unroll #pragma unroll
for (size_t index = threadIdx.x + block_start; index < block_end; index += BLOCK_SIZE) { for (size_t index = threadIdx.x + block_start; index < block_end;
index += BLOCK_SIZE) {
if (index < num_items) { if (index < num_items) {
const Iterator pos = table.Insert(items[index], index); const Iterator pos = table.Insert(items[index], index);
...@@ -237,35 +229,34 @@ __global__ void generate_hashmap_unique( ...@@ -237,35 +229,34 @@ __global__ void generate_hashmap_unique(
} }
/** /**
* \brief This counts the number of nodes inserted per thread block. * \brief This counts the number of nodes inserted per thread block.
* *
* \tparam IdType The type of id. * \tparam IdType The type of id.
* \tparam BLOCK_SIZE The size of the thread block. * \tparam BLOCK_SIZE The size of the thread block.
* \tparam TILE_SIZE The number of entries each thread block will process. * \tparam TILE_SIZE The number of entries each thread block will process.
* \param input The nodes to insert. * \param input The nodes to insert.
* \param num_input The number of nodes to insert. * \param num_input The number of nodes to insert.
* \param table The hash table. * \param table The hash table.
* \param num_unique The number of nodes inserted into the hash table per thread * \param num_unique The number of nodes inserted into the hash table per thread
* block. * block.
*/ */
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE> template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void count_hashmap( __global__ void count_hashmap(
const IdType * items, const IdType* items, const size_t num_items,
const size_t num_items, DeviceOrderedHashTable<IdType> table, IdType* const num_unique) {
DeviceOrderedHashTable<IdType> table,
IdType * const num_unique) {
assert(BLOCK_SIZE == blockDim.x); assert(BLOCK_SIZE == blockDim.x);
using BlockReduce = typename cub::BlockReduce<IdType, BLOCK_SIZE>; using BlockReduce = typename cub::BlockReduce<IdType, BLOCK_SIZE>;
using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping; using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;
const size_t block_start = TILE_SIZE*blockIdx.x; const size_t block_start = TILE_SIZE * blockIdx.x;
const size_t block_end = TILE_SIZE*(blockIdx.x+1); const size_t block_end = TILE_SIZE * (blockIdx.x + 1);
IdType count = 0; IdType count = 0;
#pragma unroll #pragma unroll
for (size_t index = threadIdx.x + block_start; index < block_end; index += BLOCK_SIZE) { for (size_t index = threadIdx.x + block_start; index < block_end;
index += BLOCK_SIZE) {
if (index < num_items) { if (index < num_items) {
const Mapping& mapping = *table.Search(items[index]); const Mapping& mapping = *table.Search(items[index]);
if (mapping.index == index) { if (mapping.index == index) {
...@@ -286,29 +277,26 @@ __global__ void count_hashmap( ...@@ -286,29 +277,26 @@ __global__ void count_hashmap(
} }
} }
/** /**
* \brief Update the local numbering of elements in the hashmap. * \brief Update the local numbering of elements in the hashmap.
* *
* \tparam IdType The type of id. * \tparam IdType The type of id.
* \tparam BLOCK_SIZE The size of the thread blocks. * \tparam BLOCK_SIZE The size of the thread blocks.
* \tparam TILE_SIZE The number of elements each thread block works on. * \tparam TILE_SIZE The number of elements each thread block works on.
* \param items The set of non-unique items to update from. * \param items The set of non-unique items to update from.
* \param num_items The number of non-unique items. * \param num_items The number of non-unique items.
* \param table The hash table. * \param table The hash table.
* \param num_items_prefix The number of unique items preceding each thread * \param num_items_prefix The number of unique items preceding each thread
* block. * block.
* \param unique_items The set of unique items (output). * \param unique_items The set of unique items (output).
* \param num_unique_items The number of unique items (output). * \param num_unique_items The number of unique items (output).
*/ */
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE> template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void compact_hashmap( __global__ void compact_hashmap(
const IdType * const items, const IdType* const items, const size_t num_items,
const size_t num_items,
MutableDeviceOrderedHashTable<IdType> table, MutableDeviceOrderedHashTable<IdType> table,
const IdType * const num_items_prefix, const IdType* const num_items_prefix, IdType* const unique_items,
IdType * const unique_items, int64_t* const num_unique_items) {
int64_t * const num_unique_items) {
assert(BLOCK_SIZE == blockDim.x); assert(BLOCK_SIZE == blockDim.x);
using FlagType = uint16_t; using FlagType = uint16_t;
...@@ -325,10 +313,10 @@ __global__ void compact_hashmap( ...@@ -325,10 +313,10 @@ __global__ void compact_hashmap(
// count successful placements // count successful placements
for (int32_t i = 0; i < VALS_PER_THREAD; ++i) { for (int32_t i = 0; i < VALS_PER_THREAD; ++i) {
const IdType index = threadIdx.x + i*BLOCK_SIZE + blockIdx.x*TILE_SIZE; const IdType index = threadIdx.x + i * BLOCK_SIZE + blockIdx.x * TILE_SIZE;
FlagType flag; FlagType flag;
Mapping * kv; Mapping* kv;
if (index < num_items) { if (index < num_items) {
kv = table.Search(items[index]); kv = table.Search(items[index]);
flag = kv->index == index; flag = kv->index == index;
...@@ -344,7 +332,7 @@ __global__ void compact_hashmap( ...@@ -344,7 +332,7 @@ __global__ void compact_hashmap(
__syncthreads(); __syncthreads();
if (kv) { if (kv) {
const IdType pos = offset+flag; const IdType pos = offset + flag;
kv->local = pos; kv->local = pos;
unique_items[pos] = items[index]; unique_items[pos] = items[index];
} }
...@@ -357,128 +345,94 @@ __global__ void compact_hashmap( ...@@ -357,128 +345,94 @@ __global__ void compact_hashmap(
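For readers following the two kernels above, the sketch below is a single-threaded analogue of the count / exclusive-scan / compact pipeline they implement (an editor's illustration, not part of this commit; the function name and the use of std::unordered_map as a stand-in for the GPU hash table are assumptions made only for clarity).

#include <cstdint>
#include <unordered_map>
#include <vector>

// Sequential stand-in for generate_hashmap_duplicates + count_hashmap +
// cub::DeviceScan::ExclusiveSum + compact_hashmap. Purely illustrative.
std::vector<int64_t> CompactByTilesSketch(
    const std::vector<int64_t>& items, const size_t tile_size) {
  // Phase 1: record the first occurrence of every key and count how many
  // first occurrences fall into each tile of `tile_size` items.
  std::unordered_map<int64_t, size_t> first_pos;
  const size_t num_tiles = (items.size() + tile_size - 1) / tile_size;
  std::vector<size_t> tile_count(num_tiles + 1, 0);
  for (size_t i = 0; i < items.size(); ++i) {
    if (first_pos.emplace(items[i], i).second) {
      ++tile_count[i / tile_size];
    }
  }
  // Phase 2: exclusive prefix sum over the per-tile counts, so every tile
  // knows where its unique items start in the compacted output.
  size_t running = 0;
  for (size_t t = 0; t < tile_count.size(); ++t) {
    const size_t count = tile_count[t];
    tile_count[t] = running;
    running += count;
  }
  // Phase 3: write each first occurrence to "tile offset + rank within tile",
  // preserving the order in which keys were first seen.
  std::vector<int64_t> unique(running);
  std::vector<size_t> cursor = tile_count;
  for (size_t i = 0; i < items.size(); ++i) {
    if (first_pos[items[i]] == i) {
      unique[cursor[i / tile_size]++] = items[i];
    }
  }
  return unique;
}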
// DeviceOrderedHashTable implementation // DeviceOrderedHashTable implementation
template<typename IdType> template <typename IdType>
DeviceOrderedHashTable<IdType>::DeviceOrderedHashTable( DeviceOrderedHashTable<IdType>::DeviceOrderedHashTable(
const Mapping* const table, const Mapping* const table, const size_t size)
const size_t size) : : table_(table), size_(size) {}
table_(table),
size_(size) {
}
template<typename IdType> template <typename IdType>
DeviceOrderedHashTable<IdType> OrderedHashTable<IdType>::DeviceHandle() const { DeviceOrderedHashTable<IdType> OrderedHashTable<IdType>::DeviceHandle() const {
return DeviceOrderedHashTable<IdType>(table_, size_); return DeviceOrderedHashTable<IdType>(table_, size_);
} }
// OrderedHashTable implementation // OrderedHashTable implementation
template<typename IdType> template <typename IdType>
OrderedHashTable<IdType>::OrderedHashTable( OrderedHashTable<IdType>::OrderedHashTable(
const size_t size, const size_t size, DGLContext ctx, cudaStream_t stream, const int scale)
DGLContext ctx, : table_(nullptr), size_(TableSize(size, scale)), ctx_(ctx) {
cudaStream_t stream,
const int scale) :
table_(nullptr),
size_(TableSize(size, scale)),
ctx_(ctx) {
// make sure we will have at least as many buckets as items. // make sure we will have at least as many buckets as items.
CHECK_GT(scale, 0); CHECK_GT(scale, 0);
auto device = runtime::DeviceAPI::Get(ctx_); auto device = runtime::DeviceAPI::Get(ctx_);
table_ = static_cast<Mapping*>( table_ = static_cast<Mapping*>(
device->AllocWorkspace(ctx_, sizeof(Mapping)*size_)); device->AllocWorkspace(ctx_, sizeof(Mapping) * size_));
CUDA_CALL(cudaMemsetAsync( CUDA_CALL(cudaMemsetAsync(
table_, table_, DeviceOrderedHashTable<IdType>::kEmptyKey,
DeviceOrderedHashTable<IdType>::kEmptyKey, sizeof(Mapping) * size_, stream));
sizeof(Mapping)*size_,
stream));
} }
template<typename IdType> template <typename IdType>
OrderedHashTable<IdType>::~OrderedHashTable() { OrderedHashTable<IdType>::~OrderedHashTable() {
auto device = runtime::DeviceAPI::Get(ctx_); auto device = runtime::DeviceAPI::Get(ctx_);
device->FreeWorkspace(ctx_, table_); device->FreeWorkspace(ctx_, table_);
} }
template<typename IdType> template <typename IdType>
void OrderedHashTable<IdType>::FillWithDuplicates( void OrderedHashTable<IdType>::FillWithDuplicates(
const IdType * const input, const IdType* const input, const size_t num_input, IdType* const unique,
const size_t num_input, int64_t* const num_unique, cudaStream_t stream) {
IdType * const unique,
int64_t * const num_unique,
cudaStream_t stream) {
auto device = runtime::DeviceAPI::Get(ctx_); auto device = runtime::DeviceAPI::Get(ctx_);
const int64_t num_tiles = (num_input+TILE_SIZE-1)/TILE_SIZE; const int64_t num_tiles = (num_input + TILE_SIZE - 1) / TILE_SIZE;
const dim3 grid(num_tiles); const dim3 grid(num_tiles);
const dim3 block(BLOCK_SIZE); const dim3 block(BLOCK_SIZE);
auto device_table = MutableDeviceOrderedHashTable<IdType>(this); auto device_table = MutableDeviceOrderedHashTable<IdType>(this);
CUDA_KERNEL_CALL((generate_hashmap_duplicates<IdType, BLOCK_SIZE, TILE_SIZE>), CUDA_KERNEL_CALL(
grid, block, 0, stream, (generate_hashmap_duplicates<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block,
input, 0, stream, input, num_input, device_table);
num_input,
device_table);
IdType * item_prefix = static_cast<IdType*>( IdType* item_prefix = static_cast<IdType*>(
device->AllocWorkspace(ctx_, sizeof(IdType)*(num_input+1))); device->AllocWorkspace(ctx_, sizeof(IdType) * (num_input + 1)));
CUDA_KERNEL_CALL((count_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), CUDA_KERNEL_CALL(
grid, block, 0, stream, (count_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
input, input, num_input, device_table, item_prefix);
num_input,
device_table,
item_prefix);
size_t workspace_bytes; size_t workspace_bytes;
CUDA_CALL(cub::DeviceScan::ExclusiveSum( CUDA_CALL(cub::DeviceScan::ExclusiveSum(
nullptr, nullptr, workspace_bytes, static_cast<IdType*>(nullptr),
workspace_bytes, static_cast<IdType*>(nullptr), grid.x + 1, stream));
static_cast<IdType*>(nullptr), void* workspace = device->AllocWorkspace(ctx_, workspace_bytes);
static_cast<IdType*>(nullptr),
grid.x+1, stream));
void * workspace = device->AllocWorkspace(ctx_, workspace_bytes);
CUDA_CALL(cub::DeviceScan::ExclusiveSum( CUDA_CALL(cub::DeviceScan::ExclusiveSum(
workspace, workspace, workspace_bytes, item_prefix, item_prefix, grid.x + 1,
workspace_bytes, stream));
item_prefix,
item_prefix,
grid.x+1, stream));
device->FreeWorkspace(ctx_, workspace); device->FreeWorkspace(ctx_, workspace);
CUDA_KERNEL_CALL((compact_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), CUDA_KERNEL_CALL(
grid, block, 0, stream, (compact_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
input, input, num_input, device_table, item_prefix, unique, num_unique);
num_input,
device_table,
item_prefix,
unique,
num_unique);
device->FreeWorkspace(ctx_, item_prefix); device->FreeWorkspace(ctx_, item_prefix);
} }
template<typename IdType> template <typename IdType>
void OrderedHashTable<IdType>::FillWithUnique( void OrderedHashTable<IdType>::FillWithUnique(
const IdType * const input, const IdType* const input, const size_t num_input, cudaStream_t stream) {
const size_t num_input, const int64_t num_tiles = (num_input + TILE_SIZE - 1) / TILE_SIZE;
cudaStream_t stream) {
const int64_t num_tiles = (num_input+TILE_SIZE-1)/TILE_SIZE;
const dim3 grid(num_tiles); const dim3 grid(num_tiles);
const dim3 block(BLOCK_SIZE); const dim3 block(BLOCK_SIZE);
auto device_table = MutableDeviceOrderedHashTable<IdType>(this); auto device_table = MutableDeviceOrderedHashTable<IdType>(this);
CUDA_KERNEL_CALL((generate_hashmap_unique<IdType, BLOCK_SIZE, TILE_SIZE>), CUDA_KERNEL_CALL(
grid, block, 0, stream, (generate_hashmap_unique<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0,
input, stream, input, num_input, device_table);
num_input,
device_table);
} }
template class OrderedHashTable<int32_t>; template class OrderedHashTable<int32_t>;
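As a usage illustration of the implementation above, here is a hedged host-side sketch (an editor's addition, not part of this commit). The names are placeholders; it assumes d_input is a device pointer to possibly duplicated IDs and, as in compact_hashmap, that the num_unique output is written from device code and must therefore live in GPU memory.

// Hedged usage sketch; error handling omitted.
void DeduplicateSketch(
    const int32_t* const d_input, const size_t num_input, DGLContext ctx,
    cudaStream_t stream) {
  auto device = runtime::DeviceAPI::Get(ctx);
  OrderedHashTable<int32_t> table(num_input, ctx, stream);

  int32_t* d_unique = static_cast<int32_t*>(
      device->AllocWorkspace(ctx, sizeof(int32_t) * num_input));
  int64_t* d_num_unique = static_cast<int64_t*>(
      device->AllocWorkspace(ctx, sizeof(int64_t)));

  table.FillWithDuplicates(d_input, num_input, d_unique, d_num_unique, stream);

  int64_t num_unique = 0;
  CUDA_CALL(cudaMemcpyAsync(
      &num_unique, d_num_unique, sizeof(int64_t), cudaMemcpyDeviceToHost,
      stream));
  CUDA_CALL(cudaStreamSynchronize(stream));
  // d_unique[0..num_unique) now holds each distinct ID once; the filled table
  // (via DeviceHandle()) maps any original ID back to its compacted index.

  device->FreeWorkspace(ctx, d_num_unique);
  device->FreeWorkspace(ctx, d_unique);
}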
......
...@@ -9,14 +9,14 @@ ...@@ -9,14 +9,14 @@
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include "cuda_runtime.h"
#include "cuda_common.h" #include "cuda_common.h"
#include "cuda_runtime.h"
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
namespace cuda { namespace cuda {
template<typename> template <typename>
class OrderedHashTable; class OrderedHashTable;
/*! /*!
...@@ -31,7 +31,7 @@ class OrderedHashTable; ...@@ -31,7 +31,7 @@ class OrderedHashTable;
* used. * used.
* *
* The hash table should be used in two phases, with the first being populating * The hash table should be used in two phases, with the first being populating
* the hash table with the OrderedHashTable object, and then generating this * the hash table with the OrderedHashTable object, and then generating this
* handle from it. This object can then be used to search the hash table, * handle from it. This object can then be used to search the hash table,
* to find mappings, from within CUDA code. * to find mappings, from within CUDA code.
* *
...@@ -62,7 +62,7 @@ class OrderedHashTable; ...@@ -62,7 +62,7 @@ class OrderedHashTable;
* *
* \tparam IdType The type of the IDs. * \tparam IdType The type of the IDs.
*/ */
template<typename IdType> template <typename IdType>
class DeviceOrderedHashTable { class DeviceOrderedHashTable {
public: public:
/** /**
...@@ -80,16 +80,15 @@ class DeviceOrderedHashTable { ...@@ -80,16 +80,15 @@ class DeviceOrderedHashTable {
/** /**
* \brief The index of the item when inserted into the hashtable (e.g., * \brief The index of the item when inserted into the hashtable (e.g.,
* the index within the array passed into FillWithDuplicates()). * the index within the array passed into FillWithDuplicates()).
*/ */
int64_t index; int64_t index;
}; };
typedef const Mapping* ConstIterator; typedef const Mapping* ConstIterator;
DeviceOrderedHashTable( DeviceOrderedHashTable(const DeviceOrderedHashTable& other) = default;
const DeviceOrderedHashTable& other) = default; DeviceOrderedHashTable& operator=(const DeviceOrderedHashTable& other) =
DeviceOrderedHashTable& operator=( default;
const DeviceOrderedHashTable& other) = default;
/** /**
* \brief Find the non-mutable mapping of a given key within the hash table. * \brief Find the non-mutable mapping of a given key within the hash table.
...@@ -101,8 +100,7 @@ class DeviceOrderedHashTable { ...@@ -101,8 +100,7 @@ class DeviceOrderedHashTable {
* *
* \return An iterator to the mapping. * \return An iterator to the mapping.
*/ */
inline __device__ ConstIterator Search( inline __device__ ConstIterator Search(const IdType id) const {
const IdType id) const {
const IdType pos = SearchForPosition(id); const IdType pos = SearchForPosition(id);
return &table_[pos]; return &table_[pos];
...@@ -115,8 +113,7 @@ class DeviceOrderedHashTable { ...@@ -115,8 +113,7 @@ class DeviceOrderedHashTable {
* *
* \return True if the key exists in the hashtable. * \return True if the key exists in the hashtable.
*/ */
inline __device__ bool Contains( inline __device__ bool Contains(const IdType id) const {
const IdType id) const {
IdType pos = Hash(id); IdType pos = Hash(id);
IdType delta = 1; IdType delta = 1;
...@@ -124,8 +121,8 @@ class DeviceOrderedHashTable { ...@@ -124,8 +121,8 @@ class DeviceOrderedHashTable {
if (table_[pos].key == id) { if (table_[pos].key == id) {
return true; return true;
} }
pos = Hash(pos+delta); pos = Hash(pos + delta);
delta +=1; delta += 1;
} }
return false; return false;
} }
...@@ -134,7 +131,7 @@ class DeviceOrderedHashTable { ...@@ -134,7 +131,7 @@ class DeviceOrderedHashTable {
// Must be uniform bytes for memset to work // Must be uniform bytes for memset to work
static constexpr IdType kEmptyKey = static_cast<IdType>(-1); static constexpr IdType kEmptyKey = static_cast<IdType>(-1);
const Mapping * table_; const Mapping* table_;
size_t size_; size_t size_;
/** /**
...@@ -143,9 +140,7 @@ class DeviceOrderedHashTable { ...@@ -143,9 +140,7 @@ class DeviceOrderedHashTable {
* \param table The table stored in GPU memory. * \param table The table stored in GPU memory.
* \param size The size of the table. * \param size The size of the table.
*/ */
explicit DeviceOrderedHashTable( explicit DeviceOrderedHashTable(const Mapping* table, size_t size);
const Mapping * table,
size_t size);
/** /**
* \brief Search for an item in the hash table which is known to exist. * \brief Search for an item in the hash table which is known to exist.
...@@ -157,16 +152,15 @@ class DeviceOrderedHashTable { ...@@ -157,16 +152,15 @@ class DeviceOrderedHashTable {
* *
* \return The position of the item in the hashtable. * \return The position of the item in the hashtable.
*/ */
inline __device__ IdType SearchForPosition( inline __device__ IdType SearchForPosition(const IdType id) const {
const IdType id) const {
IdType pos = Hash(id); IdType pos = Hash(id);
// linearly scan for matching entry // linearly scan for matching entry
IdType delta = 1; IdType delta = 1;
while (table_[pos].key != id) { while (table_[pos].key != id) {
assert(table_[pos].key != kEmptyKey); assert(table_[pos].key != kEmptyKey);
pos = Hash(pos+delta); pos = Hash(pos + delta);
delta +=1; delta += 1;
} }
assert(pos < size_); assert(pos < size_);
...@@ -180,10 +174,7 @@ class DeviceOrderedHashTable { ...@@ -180,10 +174,7 @@ class DeviceOrderedHashTable {
* *
* \return The hash. * \return The hash.
*/ */
inline __device__ size_t Hash( inline __device__ size_t Hash(const IdType id) const { return id % size_; }
const IdType id) const {
return id % size_;
}
friend class OrderedHashTable<IdType>; friend class OrderedHashTable<IdType>;
}; };
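To make the two-phase usage described in the class comment concrete, the following is a hedged sketch of the second phase (an editor's illustration, not part of this commit). The kernel name and parameters are hypothetical, and it assumes every queried ID was inserted during the population phase, since Search() asserts on missing keys.

template <typename IdType>
__global__ void _RemapSketchKernel(
    const IdType* const global_ids, const int64_t num_ids,
    DeviceOrderedHashTable<IdType> table, IdType* const local_ids) {
  const int64_t tidx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;
  if (tidx < num_ids) {
    // Look up the mapping created while populating the table, and emit the
    // compacted ("local") ID assigned by compact_hashmap.
    const auto* mapping = table.Search(global_ids[tidx]);
    local_ids[tidx] = static_cast<IdType>(mapping->local);
  }
}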
...@@ -219,7 +210,7 @@ class DeviceOrderedHashTable { ...@@ -219,7 +210,7 @@ class DeviceOrderedHashTable {
* *
* \tparam IdType The type of the IDs. * \tparam IdType The type of the IDs.
*/ */
template<typename IdType> template <typename IdType>
class OrderedHashTable { class OrderedHashTable {
public: public:
static constexpr int kDefaultScale = 3; static constexpr int kDefaultScale = 3;
...@@ -237,9 +228,7 @@ class OrderedHashTable { ...@@ -237,9 +228,7 @@ class OrderedHashTable {
* \param stream The stream to use for initializing the hashtable. * \param stream The stream to use for initializing the hashtable.
*/ */
OrderedHashTable( OrderedHashTable(
const size_t size, const size_t size, DGLContext ctx, cudaStream_t stream,
DGLContext ctx,
cudaStream_t stream,
const int scale = kDefaultScale); const int scale = kDefaultScale);
/** /**
...@@ -248,10 +237,8 @@ class OrderedHashTable { ...@@ -248,10 +237,8 @@ class OrderedHashTable {
~OrderedHashTable(); ~OrderedHashTable();
// Disable copying // Disable copying
OrderedHashTable( OrderedHashTable(const OrderedHashTable& other) = delete;
const OrderedHashTable& other) = delete; OrderedHashTable& operator=(const OrderedHashTable& other) = delete;
OrderedHashTable& operator=(
const OrderedHashTable& other) = delete;
/** /**
* \brief Fill the hashtable with the array containing possibly duplicate * \brief Fill the hashtable with the array containing possibly duplicate
...@@ -264,11 +251,8 @@ class OrderedHashTable { ...@@ -264,11 +251,8 @@ class OrderedHashTable {
* \param stream The stream to perform operations on. * \param stream The stream to perform operations on.
*/ */
void FillWithDuplicates( void FillWithDuplicates(
const IdType * const input, const IdType* const input, const size_t num_input, IdType* const unique,
const size_t num_input, int64_t* const num_unique, cudaStream_t stream);
IdType * const unique,
int64_t * const num_unique,
cudaStream_t stream);
/** /**
* \brief Fill the hashtable with an array of unique keys. * \brief Fill the hashtable with an array of unique keys.
...@@ -278,9 +262,7 @@ class OrderedHashTable { ...@@ -278,9 +262,7 @@ class OrderedHashTable {
* \param stream The stream to perform operations on. * \param stream The stream to perform operations on.
*/ */
void FillWithUnique( void FillWithUnique(
const IdType * const input, const IdType* const input, const size_t num_input, cudaStream_t stream);
const size_t num_input,
cudaStream_t stream);
/** /**
* \brief Get a version of the hashtable usable from device functions. * \brief Get a version of the hashtable usable from device functions.
...@@ -290,12 +272,11 @@ class OrderedHashTable { ...@@ -290,12 +272,11 @@ class OrderedHashTable {
DeviceOrderedHashTable<IdType> DeviceHandle() const; DeviceOrderedHashTable<IdType> DeviceHandle() const;
private: private:
Mapping * table_; Mapping* table_;
size_t size_; size_t size_;
DGLContext ctx_; DGLContext ctx_;
}; };
} // namespace cuda } // namespace cuda
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
......
...@@ -17,42 +17,39 @@ ...@@ -17,42 +17,39 @@
* \brief Implementation of wrapper around NCCL routines. * \brief Implementation of wrapper around NCCL routines.
*/ */
#include <cuda_fp16.h>
#include "nccl_api.h" #include <cuda_runtime.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/aten/array_ops.h> #include <dgl/aten/array_ops.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/registry.h> #include <dgl/runtime/registry.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <algorithm>
#include <cmath> #include <cmath>
#include <sstream>
#include <iomanip> #include <iomanip>
#include <utility> #include <limits>
#include <vector>
#include <memory> #include <memory>
#include <sstream>
#include <string> #include <string>
#include <algorithm> #include <utility>
#include <limits> #include <vector>
#include "cuda_common.h"
#include "../../runtime/workspace.h"
#include "../../partition/ndarray_partition.h"
#include "../../array/cuda/dgl_cub.cuh"
#include "../../array/cuda/array_index_select.cuh" #include "../../array/cuda/array_index_select.cuh"
#include "../../array/cuda/dgl_cub.cuh"
#include "../../partition/ndarray_partition.h"
#include "../../runtime/workspace.h"
#include "cuda_common.h"
#include "nccl_api.h"
#define NCCL_CALL(func) \ #define NCCL_CALL(func) \
{ \ { \
ncclResult_t result = func; \ ncclResult_t result = func; \
if (result != ncclSuccess) { \ if (result != ncclSuccess) { \
LOG(FATAL) \ LOG(FATAL) << "NCCLError: " #func " failed with error: " << result; \
<< "NCCLError: " #func " failed with error: " << result; \ } \
} \ }
}
namespace dgl { namespace dgl {
...@@ -65,36 +62,39 @@ namespace { ...@@ -65,36 +62,39 @@ namespace {
#ifdef DGL_USE_NCCL #ifdef DGL_USE_NCCL
template<typename T> ncclDataType_t NCCLType(); template <typename T>
template<> ncclDataType_t NCCLType<int32_t>() { ncclDataType_t NCCLType();
return ncclInt32; template <>
ncclDataType_t NCCLType<int32_t>() {
return ncclInt32;
} }
template<> ncclDataType_t NCCLType<int64_t>() { template <>
return ncclInt64; ncclDataType_t NCCLType<int64_t>() {
return ncclInt64;
} }
template<> ncclDataType_t NCCLType<__half>() { template <>
return ncclHalf; ncclDataType_t NCCLType<__half>() {
return ncclHalf;
} }
template<> ncclDataType_t NCCLType<float>() { template <>
return ncclFloat32; ncclDataType_t NCCLType<float>() {
return ncclFloat32;
} }
template<> ncclDataType_t NCCLType<double>() { template <>
return ncclFloat64; ncclDataType_t NCCLType<double>() {
return ncclFloat64;
} }
#endif // DGL_USE_NCCL #endif // DGL_USE_NCCL
template<typename IdType, typename DType> template <typename IdType, typename DType>
__global__ void _DualPermKernel( __global__ void _DualPermKernel(
const IdType * const in_idx, const IdType* const in_idx, const DType* const in_value,
const DType * const in_value, const IdType* const perm, const int64_t num_in, const int64_t num_feat,
const IdType * const perm, IdType* const out_idx, DType* const out_value) {
const int64_t num_in,
const int64_t num_feat,
IdType * const out_idx,
DType * const out_value) {
// set index permutation // set index permutation
const int64_t tidx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x; const int64_t tidx =
blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;
if (tidx < num_in) { if (tidx < num_in) {
const IdType perm_idx = perm[tidx]; const IdType perm_idx = perm[tidx];
assert(perm_idx < num_in); assert(perm_idx < num_in);
...@@ -103,11 +103,11 @@ __global__ void _DualPermKernel( ...@@ -103,11 +103,11 @@ __global__ void _DualPermKernel(
if (num_feat > 1) { if (num_feat > 1) {
for (int d = 0; d < blockDim.x; ++d) { for (int d = 0; d < blockDim.x; ++d) {
const int64_t bidx = blockDim.x*static_cast<int64_t>(blockIdx.x) + d; const int64_t bidx = blockDim.x * static_cast<int64_t>(blockIdx.x) + d;
if (bidx < num_in) { if (bidx < num_in) {
const IdType perm_idx = perm[bidx]; const IdType perm_idx = perm[bidx];
for (int64_t f = threadIdx.x; f < num_feat; f+=blockDim.x) { for (int64_t f = threadIdx.x; f < num_feat; f += blockDim.x) {
out_value[bidx*num_feat+f] = in_value[perm_idx*num_feat+f]; out_value[bidx * num_feat + f] = in_value[perm_idx * num_feat + f];
} }
} }
} }
...@@ -121,48 +121,43 @@ __global__ void _DualPermKernel( ...@@ -121,48 +121,43 @@ __global__ void _DualPermKernel(
template <typename DType, typename IdType> template <typename DType, typename IdType>
__global__ void _InversePermKernel( __global__ void _InversePermKernel(
const DType* const array, const DType* const array, const int64_t num_feat, int64_t length,
const int64_t num_feat, const IdType* const perm, DType* const out) {
int64_t length, int64_t in_row = blockIdx.x * blockDim.y + threadIdx.y;
const IdType* const perm,
DType* const out) {
int64_t in_row = blockIdx.x*blockDim.y+threadIdx.y;
const int64_t stride = blockDim.y*gridDim.x; const int64_t stride = blockDim.y * gridDim.x;
while (in_row < length) { while (in_row < length) {
int64_t col = threadIdx.x; int64_t col = threadIdx.x;
const int64_t out_row = perm[in_row]; const int64_t out_row = perm[in_row];
while (col < num_feat) { while (col < num_feat) {
out[out_row*num_feat+col] = array[in_row*num_feat+col]; out[out_row * num_feat + col] = array[in_row * num_feat + col];
col += blockDim.x; col += blockDim.x;
} }
in_row += stride; in_row += stride;
} }
} }
template <typename IdType, typename DType>
template<typename IdType, typename DType>
std::pair<IdArray, NDArray> SparsePush( std::pair<IdArray, NDArray> SparsePush(
NCCLCommunicatorRef comm, NCCLCommunicatorRef comm, IdArray in_idx, NDArray in_value,
IdArray in_idx,
NDArray in_value,
NDArrayPartitionRef part) { NDArrayPartitionRef part) {
const auto& ctx = in_idx->ctx; const auto& ctx = in_idx->ctx;
CHECK_EQ(ctx, in_value->ctx) << "Indices and values must be on the same " CHECK_EQ(ctx, in_value->ctx) << "Indices and values must be on the same "
"device"; "device";
auto device = DeviceAPI::Get(ctx); auto device = DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
CHECK_LE(in_idx->ndim, 1) << "The tensor of sending indices must be of " CHECK_LE(in_idx->ndim, 1) << "The tensor of sending indices must be of "
"dimension one (or empty)."; "dimension one (or empty).";
const int64_t num_in = in_idx->ndim > 0 ? in_idx->shape[0] : 0; const int64_t num_in = in_idx->ndim > 0 ? in_idx->shape[0] : 0;
CHECK_EQ(num_in, in_value->ndim > 0 ? in_value->shape[0] : 0) << CHECK_EQ(num_in, in_value->ndim > 0 ? in_value->shape[0] : 0)
"Leading dimension of indices (" << num_in << ") must match " << "Leading dimension of indices (" << num_in
"leading dimension of values (" << << ") must match "
(in_value->ndim > 0 ? in_value->shape[0] : 0) << ")."; "leading dimension of values ("
<< (in_value->ndim > 0 ? in_value->shape[0] : 0) << ").";
int64_t num_feat = 1; int64_t num_feat = 1;
for (int d = 1; d < in_value->ndim; ++d) { for (int d = 1; d < in_value->ndim; ++d) {
...@@ -177,91 +172,83 @@ std::pair<IdArray, NDArray> SparsePush( ...@@ -177,91 +172,83 @@ std::pair<IdArray, NDArray> SparsePush(
} }
std::pair<IdArray, NDArray> part_perm = part->GeneratePermutation(in_idx); std::pair<IdArray, NDArray> part_perm = part->GeneratePermutation(in_idx);
const IdType * const perm = static_cast<const IdType*>(part_perm.first->data); const IdType* const perm = static_cast<const IdType*>(part_perm.first->data);
const int64_t * const send_sum = const int64_t* const send_sum =
static_cast<const int64_t*>(part_perm.second->data); static_cast<const int64_t*>(part_perm.second->data);
Workspace<IdType> send_idx(device, ctx, num_in); Workspace<IdType> send_idx(device, ctx, num_in);
Workspace<DType> send_value(device, ctx, num_in*num_feat); Workspace<DType> send_value(device, ctx, num_in * num_feat);
// permute the indices and values // permute the indices and values
if (num_in > 0) { if (num_in > 0) {
const dim3 block(256); const dim3 block(256);
const dim3 grid((num_in+block.x-1)/block.x); const dim3 grid((num_in + block.x - 1) / block.x);
CUDA_KERNEL_CALL(_DualPermKernel, CUDA_KERNEL_CALL(
grid, block, 0, stream, _DualPermKernel, grid, block, 0, stream,
static_cast<const IdType*>(in_idx->data), static_cast<const IdType*>(in_idx->data),
static_cast<const DType*>(in_value->data), static_cast<const DType*>(in_value->data), perm, num_in, num_feat,
perm, send_idx.get(), send_value.get());
num_in,
num_feat,
send_idx.get(),
send_value.get());
} }
// compute the prefix sum of the send values // compute the prefix sum of the send values
Workspace<int64_t> send_prefix(device, ctx, comm_size+1); Workspace<int64_t> send_prefix(device, ctx, comm_size + 1);
{ {
size_t prefix_workspace_size; size_t prefix_workspace_size;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_workspace_size, CUDA_CALL(cub::DeviceScan::ExclusiveSum(
send_sum, send_prefix.get(), comm_size+1, stream)); nullptr, prefix_workspace_size, send_sum, send_prefix.get(),
comm_size + 1, stream));
Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size); Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_workspace.get(), CUDA_CALL(cub::DeviceScan::ExclusiveSum(
prefix_workspace_size, send_sum, send_prefix.get(), prefix_workspace.get(), prefix_workspace_size, send_sum,
comm_size+1, stream)); send_prefix.get(), comm_size + 1, stream));
} }
std::vector<int64_t> send_prefix_host(comm_size+1); std::vector<int64_t> send_prefix_host(comm_size + 1);
// copy using the same stream (local current stream), no need to sync // copy using the same stream (local current stream), no need to sync
device->CopyDataFromTo( device->CopyDataFromTo(
send_prefix.get(), send_prefix.get(), 0, send_prefix_host.data(), 0,
0, send_prefix_host.size() * sizeof(*send_prefix.get()), ctx,
send_prefix_host.data(),
0,
send_prefix_host.size()*sizeof(*send_prefix.get()),
ctx,
DGLContext{kDGLCPU, 0}, DGLContext{kDGLCPU, 0},
DGLDataType{kDGLInt, sizeof(*send_prefix.get())*8, 1}); DGLDataType{kDGLInt, sizeof(*send_prefix.get()) * 8, 1});
send_prefix.free(); send_prefix.free();
CHECK_EQ(send_prefix_host.back(), num_in) << "Internal Error: " CHECK_EQ(send_prefix_host.back(), num_in)
"send_prefix_host.back() = " << send_prefix_host.back() << << "Internal Error: "
", and num_in = " << num_in; "send_prefix_host.back() = "
<< send_prefix_host.back() << ", and num_in = " << num_in;
// communicate the amount to send // communicate the amount to send
Workspace<int64_t> recv_sum(device, ctx, comm_size+1); Workspace<int64_t> recv_sum(device, ctx, comm_size + 1);
comm->AllToAll(send_sum, recv_sum.get(), 1, stream); comm->AllToAll(send_sum, recv_sum.get(), 1, stream);
cudaEvent_t d2h; cudaEvent_t d2h;
CUDA_CALL(cudaEventCreate(&d2h)); CUDA_CALL(cudaEventCreate(&d2h));
// compute the prefix sum of the recv values // compute the prefix sum of the recv values
Workspace<int64_t> recv_prefix(device, ctx, comm_size+1); Workspace<int64_t> recv_prefix(device, ctx, comm_size + 1);
{ {
size_t prefix_workspace_size; size_t prefix_workspace_size;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_workspace_size, CUDA_CALL(cub::DeviceScan::ExclusiveSum(
recv_sum.get(), recv_prefix.get(), comm_size+1, stream)); nullptr, prefix_workspace_size, recv_sum.get(), recv_prefix.get(),
comm_size + 1, stream));
Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size); Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_workspace.get(), CUDA_CALL(cub::DeviceScan::ExclusiveSum(
prefix_workspace_size, recv_sum.get(), recv_prefix.get(), comm_size+1, stream)); prefix_workspace.get(), prefix_workspace_size, recv_sum.get(),
recv_prefix.get(), comm_size + 1, stream));
} }
recv_sum.free(); recv_sum.free();
// finally copy the prefix sum down to the host // finally copy the prefix sum down to the host
std::vector<int64_t> recv_prefix_host(comm_size+1); std::vector<int64_t> recv_prefix_host(comm_size + 1);
// copy using the same stream (local current stream), no need to sync // copy using the same stream (local current stream), no need to sync
device->CopyDataFromTo( device->CopyDataFromTo(
recv_prefix.get(), recv_prefix.get(), 0, recv_prefix_host.data(), 0,
0, recv_prefix_host.size() * sizeof(*recv_prefix.get()), ctx,
recv_prefix_host.data(),
0,
recv_prefix_host.size()*sizeof(*recv_prefix.get()),
ctx,
DGLContext{kDGLCPU, 0}, DGLContext{kDGLCPU, 0},
DGLDataType{kDGLInt, sizeof(*recv_prefix.get())*8, 1}); DGLDataType{kDGLInt, sizeof(*recv_prefix.get()) * 8, 1});
recv_prefix.free(); recv_prefix.free();
// use an event to track when copying is done // use an event to track when copying is done
...@@ -271,8 +258,8 @@ std::pair<IdArray, NDArray> SparsePush( ...@@ -271,8 +258,8 @@ std::pair<IdArray, NDArray> SparsePush(
CUDA_CALL(cudaEventSynchronize(d2h)); CUDA_CALL(cudaEventSynchronize(d2h));
CUDA_CALL(cudaEventDestroy(d2h)); CUDA_CALL(cudaEventDestroy(d2h));
IdArray recv_idx = aten::NewIdArray( IdArray recv_idx =
recv_prefix_host.back(), ctx, sizeof(IdType)*8); aten::NewIdArray(recv_prefix_host.back(), ctx, sizeof(IdType) * 8);
std::vector<int64_t> value_shape(in_value->ndim, 0); std::vector<int64_t> value_shape(in_value->ndim, 0);
value_shape[0] = recv_prefix_host.back(); value_shape[0] = recv_prefix_host.back();
...@@ -283,33 +270,26 @@ std::pair<IdArray, NDArray> SparsePush( ...@@ -283,33 +270,26 @@ std::pair<IdArray, NDArray> SparsePush(
// send data // send data
comm->SparseAllToAll( comm->SparseAllToAll(
send_idx.get(), send_idx.get(), send_value.get(), num_feat, send_prefix_host.data(),
send_value.get(),
num_feat,
send_prefix_host.data(),
static_cast<IdType*>(recv_idx->data), static_cast<IdType*>(recv_idx->data),
static_cast<DType*>(recv_value->data), static_cast<DType*>(recv_value->data), recv_prefix_host.data(), stream);
recv_prefix_host.data(),
stream);
return std::pair<IdArray, NDArray>(recv_idx, recv_value); return std::pair<IdArray, NDArray>(recv_idx, recv_value);
} }
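A hedged numerical walkthrough of the bookkeeping above may help (an editor's illustration with made-up sizes, not part of this commit):

// Suppose comm_size == 2 and this rank pushes num_in == 5 index/value pairs,
// 3 of which the partition assigns to rank 0 and 2 to rank 1:
//   send_sum         = {3, 2}
//   send_prefix_host = {0, 3, 5}   // exclusive prefix sum; back() == num_in
// After the AllToAll of per-rank counts, suppose the peer pushes 4 to us:
//   recv_sum         = {3, 4}
//   recv_prefix_host = {0, 3, 7}   // 7 index/value rows will be received
// SparseAllToAll then ships send_idx[0..3) to rank 0 and send_idx[3..5) to
// rank 1, with the value payload using the same prefixes scaled by num_feat.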
template<typename IdType, typename DType> template <typename IdType, typename DType>
NDArray SparsePull( NDArray SparsePull(
NCCLCommunicatorRef comm, NCCLCommunicatorRef comm, IdArray req_idx, NDArray local_tensor,
IdArray req_idx,
NDArray local_tensor,
NDArrayPartitionRef part) { NDArrayPartitionRef part) {
const auto& ctx = req_idx->ctx; const auto& ctx = req_idx->ctx;
CHECK_EQ(ctx, local_tensor->ctx) << "The request indices and set of local " CHECK_EQ(ctx, local_tensor->ctx) << "The request indices and set of local "
"values must be on the same device"; "values must be on the same device";
auto device = DeviceAPI::Get(ctx); auto device = DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
CHECK_LE(req_idx->ndim, 1) << "The tensor of requested indices must be of " CHECK_LE(req_idx->ndim, 1) << "The tensor of requested indices must be of "
"dimension one (or empty)."; "dimension one (or empty).";
const int64_t num_in = req_idx->ndim > 0 ? req_idx->shape[0] : 0; const int64_t num_in = req_idx->ndim > 0 ? req_idx->shape[0] : 0;
int64_t num_feat = 1; int64_t num_feat = 1;
for (int d = 1; d < local_tensor->ndim; ++d) { for (int d = 1; d < local_tensor->ndim; ++d) {
...@@ -333,86 +313,78 @@ NDArray SparsePull( ...@@ -333,86 +313,78 @@ NDArray SparsePull(
Workspace<IdType> send_idx(device, ctx, num_in); Workspace<IdType> send_idx(device, ctx, num_in);
std::pair<IdArray, NDArray> part_perm = part->GeneratePermutation(req_idx); std::pair<IdArray, NDArray> part_perm = part->GeneratePermutation(req_idx);
const IdType * const perm = static_cast<const IdType*>(part_perm.first->data); const IdType* const perm = static_cast<const IdType*>(part_perm.first->data);
const int64_t * const send_sum = const int64_t* const send_sum =
static_cast<const int64_t*>(part_perm.second->data); static_cast<const int64_t*>(part_perm.second->data);
// permute requests // permute requests
if (num_in > 0) { if (num_in > 0) {
const dim3 block(256); const dim3 block(256);
const dim3 grid((num_in+block.x-1)/block.x); const dim3 grid((num_in + block.x - 1) / block.x);
CUDA_KERNEL_CALL(aten::impl::IndexSelectSingleKernel, CUDA_KERNEL_CALL(
grid, block, 0, stream, aten::impl::IndexSelectSingleKernel, grid, block, 0, stream,
static_cast<const IdType*>(req_idx->data), static_cast<const IdType*>(req_idx->data), perm, num_in,
perm, req_idx->shape[0], send_idx.get());
num_in,
req_idx->shape[0],
send_idx.get());
} }
// compute the prefix sum of the indexes this process is requesting // compute the prefix sum of the indexes this process is requesting
Workspace<int64_t> request_prefix(device, ctx, comm_size+1); Workspace<int64_t> request_prefix(device, ctx, comm_size + 1);
{ {
size_t prefix_workspace_size; size_t prefix_workspace_size;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_workspace_size, CUDA_CALL(cub::DeviceScan::ExclusiveSum(
send_sum, request_prefix.get(), comm_size+1, stream)); nullptr, prefix_workspace_size, send_sum, request_prefix.get(),
comm_size + 1, stream));
Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size); Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_workspace.get(), CUDA_CALL(cub::DeviceScan::ExclusiveSum(
prefix_workspace_size, send_sum, request_prefix.get(), prefix_workspace.get(), prefix_workspace_size, send_sum,
comm_size+1, stream)); request_prefix.get(), comm_size + 1, stream));
} }
cudaEvent_t d2h; cudaEvent_t d2h;
CUDA_CALL(cudaEventCreate(&d2h)); CUDA_CALL(cudaEventCreate(&d2h));
std::vector<int64_t> request_prefix_host(comm_size+1); std::vector<int64_t> request_prefix_host(comm_size + 1);
// copy using the same stream (local current stream), no need to sync // copy using the same stream (local current stream), no need to sync
device->CopyDataFromTo( device->CopyDataFromTo(
request_prefix.get(), request_prefix.get(), 0, request_prefix_host.data(), 0,
0, request_prefix_host.size() * sizeof(*request_prefix.get()), ctx,
request_prefix_host.data(),
0,
request_prefix_host.size()*sizeof(*request_prefix.get()),
ctx,
DGLContext{kDGLCPU, 0}, DGLContext{kDGLCPU, 0},
DGLDataType{kDGLInt, sizeof(*request_prefix.get())*8, 1}); DGLDataType{kDGLInt, sizeof(*request_prefix.get()) * 8, 1});
request_prefix.free(); request_prefix.free();
CHECK_EQ(request_prefix_host.back(), num_in) << "Internal Error: " CHECK_EQ(request_prefix_host.back(), num_in)
"request_prefix_host.back() = " << request_prefix_host.back() << << "Internal Error: "
", num_in = " << num_in; "request_prefix_host.back() = "
<< request_prefix_host.back() << ", num_in = " << num_in;
// communicate the amount requested // communicate the amount requested
Workspace<int64_t> recv_sum(device, ctx, comm_size+1); Workspace<int64_t> recv_sum(device, ctx, comm_size + 1);
comm->AllToAll(send_sum, recv_sum.get(), 1, stream); comm->AllToAll(send_sum, recv_sum.get(), 1, stream);
// compute the prefix sum of the requested indexes // compute the prefix sum of the requested indexes
Workspace<int64_t> response_prefix(device, ctx, comm_size+1); Workspace<int64_t> response_prefix(device, ctx, comm_size + 1);
{ {
size_t prefix_workspace_size; size_t prefix_workspace_size;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_workspace_size, CUDA_CALL(cub::DeviceScan::ExclusiveSum(
recv_sum.get(), response_prefix.get(), comm_size+1, stream)); nullptr, prefix_workspace_size, recv_sum.get(), response_prefix.get(),
comm_size + 1, stream));
Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size); Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_workspace.get(), CUDA_CALL(cub::DeviceScan::ExclusiveSum(
prefix_workspace_size, recv_sum.get(), response_prefix.get(), prefix_workspace.get(), prefix_workspace_size, recv_sum.get(),
comm_size+1, stream)); response_prefix.get(), comm_size + 1, stream));
} }
recv_sum.free(); recv_sum.free();
// finally copy the prefix sum down to the host // finally copy the prefix sum down to the host
std::vector<int64_t> response_prefix_host(comm_size+1); std::vector<int64_t> response_prefix_host(comm_size + 1);
// copy using the same stream (local current stream), no need to sync // copy using the same stream (local current stream), no need to sync
device->CopyDataFromTo( device->CopyDataFromTo(
response_prefix.get(), response_prefix.get(), 0, response_prefix_host.data(), 0,
0, response_prefix_host.size() * sizeof(*response_prefix.get()), ctx,
response_prefix_host.data(),
0,
response_prefix_host.size()*sizeof(*response_prefix.get()),
ctx,
DGLContext{kDGLCPU, 0}, DGLContext{kDGLCPU, 0},
DGLDataType{kDGLInt, sizeof(*response_prefix.get())*8, 1}); DGLDataType{kDGLInt, sizeof(*response_prefix.get()) * 8, 1});
response_prefix.free(); response_prefix.free();
// use an event to track when copying is done // use an event to track when copying is done
...@@ -423,13 +395,11 @@ NDArray SparsePull( ...@@ -423,13 +395,11 @@ NDArray SparsePull(
CUDA_CALL(cudaEventDestroy(d2h)); CUDA_CALL(cudaEventDestroy(d2h));
// gather requested indexes // gather requested indexes
IdArray recv_idx = aten::NewIdArray( IdArray recv_idx =
response_prefix_host.back(), ctx, sizeof(IdType)*8); aten::NewIdArray(response_prefix_host.back(), ctx, sizeof(IdType) * 8);
comm->AllToAllV( comm->AllToAllV(
send_idx.get(), send_idx.get(), request_prefix_host.data(),
request_prefix_host.data(), static_cast<IdType*>(recv_idx->data), response_prefix_host.data(),
static_cast<IdType*>(recv_idx->data),
response_prefix_host.data(),
stream); stream);
send_idx.free(); send_idx.free();
...@@ -439,24 +409,21 @@ NDArray SparsePull( ...@@ -439,24 +409,21 @@ NDArray SparsePull(
} }
// and then index select them into place // and then index select them into place
Workspace<DType> filled_response_value(device, ctx, Workspace<DType> filled_response_value(
response_prefix_host.back()*num_feat); device, ctx, response_prefix_host.back() * num_feat);
if (response_prefix_host.back() > 0) { if (response_prefix_host.back() > 0) {
dim3 block(256, 1); dim3 block(256, 1);
while (block.x >= 2*num_feat) { while (block.x >= 2 * num_feat) {
block.x /= 2; block.x /= 2;
block.y *= 2; block.y *= 2;
} }
const dim3 grid((response_prefix_host.back()+block.y-1)/block.y); const dim3 grid((response_prefix_host.back() + block.y - 1) / block.y);
CUDA_KERNEL_CALL(aten::impl::IndexSelectMultiKernel, CUDA_KERNEL_CALL(
grid, block, 0, stream, aten::impl::IndexSelectMultiKernel, grid, block, 0, stream,
static_cast<const DType*>(local_tensor->data), static_cast<const DType*>(local_tensor->data), num_feat,
num_feat, static_cast<IdType*>(recv_idx->data), response_prefix_host.back(),
static_cast<IdType*>(recv_idx->data), local_tensor->shape[0], filled_response_value.get());
response_prefix_host.back(),
local_tensor->shape[0],
filled_response_value.get());
} }
// we will collect recieved values in this array // we will collect recieved values in this array
...@@ -465,8 +432,8 @@ NDArray SparsePull( ...@@ -465,8 +432,8 @@ NDArray SparsePull(
for (int d = 1; d < local_tensor->ndim; ++d) { for (int d = 1; d < local_tensor->ndim; ++d) {
value_shape[d] = local_tensor->shape[d]; value_shape[d] = local_tensor->shape[d];
} }
Workspace<DType> filled_request_value(device, ctx, Workspace<DType> filled_request_value(
request_prefix_host.back()*num_feat); device, ctx, request_prefix_host.back() * num_feat);
// multiply the prefixes by the number of features being sent // multiply the prefixes by the number of features being sent
for (auto& v : request_prefix_host) { for (auto& v : request_prefix_host) {
...@@ -478,30 +445,23 @@ NDArray SparsePull( ...@@ -478,30 +445,23 @@ NDArray SparsePull(
// send the values // send the values
comm->AllToAllV( comm->AllToAllV(
filled_response_value.get(), filled_response_value.get(), response_prefix_host.data(),
response_prefix_host.data(), filled_request_value.get(), request_prefix_host.data(), stream);
filled_request_value.get(),
request_prefix_host.data(),
stream);
filled_response_value.free(); filled_response_value.free();
// finally, we need to permute the values back into the requested order // finally, we need to permute the values back into the requested order
NDArray result = NDArray::Empty(value_shape, local_tensor->dtype, ctx); NDArray result = NDArray::Empty(value_shape, local_tensor->dtype, ctx);
if (num_in > 0) { if (num_in > 0) {
dim3 block(256, 1); dim3 block(256, 1);
while (block.x >= 2*num_feat) { while (block.x >= 2 * num_feat) {
block.x /= 2; block.x /= 2;
block.y *= 2; block.y *= 2;
} }
const dim3 grid((num_in+block.y-1)/block.y); const dim3 grid((num_in + block.y - 1) / block.y);
CUDA_KERNEL_CALL(_InversePermKernel, CUDA_KERNEL_CALL(
grid, block, 0, stream, _InversePermKernel, grid, block, 0, stream, filled_request_value.get(),
filled_request_value.get(), num_feat, num_in, perm, static_cast<DType*>(result->data));
num_feat,
num_in,
perm,
static_cast<DType*>(result->data));
} }
return result; return result;
...@@ -511,21 +471,18 @@ NDArray SparsePull( ...@@ -511,21 +471,18 @@ NDArray SparsePull(
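For contrast with the push path, here is a hedged sketch of the pull round trip implemented above (an editor's illustration with made-up sizes, not part of this commit):

// With comm_size == 2, suppose this rank requests 5 rows (3 that it owns
// itself, 2 owned by the peer) and the peer requests 4 rows from this rank:
//   request_prefix_host  = {0, 3, 5}   // rows we ask of each rank
//   response_prefix_host = {0, 3, 7}   // rows each rank asks of us
// Step 1: AllToAllV ships the requested indices (5 out, 7 in).
// Step 2: IndexSelectMultiKernel gathers the 7 requested rows from
//         local_tensor into filled_response_value.
// Step 3: a second AllToAllV returns the rows (7 * num_feat values out,
//         5 * num_feat values in, using the prefixes scaled by num_feat).
// Step 4: _InversePermKernel scatters the 5 received rows back into the
//         original order of req_idx to produce `result`.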
/* NCCLUniqueId **************************************************************/ /* NCCLUniqueId **************************************************************/
NCCLUniqueId::NCCLUniqueId() : NCCLUniqueId::NCCLUniqueId() : id_() {
id_() { #ifdef DGL_USE_NCCL
#ifdef DGL_USE_NCCL
// this ID is unique to the process, not to each call of this function // this ID is unique to the process, not to each call of this function
NCCL_CALL(ncclGetUniqueId(&id_)); NCCL_CALL(ncclGetUniqueId(&id_));
#else #else
// when NCCL isn't enabled, use all zeros // when NCCL isn't enabled, use all zeros
std::fill(id_.internal, id_.internal + NCCL_UNIQUE_ID_BYTES, std::fill(
static_cast<char>(0)); id_.internal, id_.internal + NCCL_UNIQUE_ID_BYTES, static_cast<char>(0));
#endif #endif
} }
ncclUniqueId NCCLUniqueId::Get() const { ncclUniqueId NCCLUniqueId::Get() const { return id_; }
return id_;
}
std::string NCCLUniqueId::ToString() const { std::string NCCLUniqueId::ToString() const {
std::ostringstream oss; std::ostringstream oss;
...@@ -538,82 +495,78 @@ std::string NCCLUniqueId::ToString() const { ...@@ -538,82 +495,78 @@ std::string NCCLUniqueId::ToString() const {
} }
std::string result = oss.str(); std::string result = oss.str();
CHECK_EQ(result.length(), NCCL_UNIQUE_ID_BYTES*2) << CHECK_EQ(result.length(), NCCL_UNIQUE_ID_BYTES * 2)
"Invalid NCCL ID format: '" << result << "'"; << "Invalid NCCL ID format: '" << result << "'";
return result; return result;
} }
void NCCLUniqueId::FromString( void NCCLUniqueId::FromString(const std::string& str) {
const std::string& str) {
// must be exactly 256 hex characters // must be exactly 256 hex characters
CHECK_EQ(str.length(), NCCL_UNIQUE_ID_BYTES * 2) << CHECK_EQ(str.length(), NCCL_UNIQUE_ID_BYTES * 2)
"Invalid NCCL ID format: '" << str << "'"; << "Invalid NCCL ID format: '" << str << "'";
for (size_t b = 0; b < NCCL_UNIQUE_ID_BYTES; ++b) { for (size_t b = 0; b < NCCL_UNIQUE_ID_BYTES; ++b) {
id_.internal[b] = std::strtol(str.substr(b*2, 2).c_str(), nullptr, 16); id_.internal[b] = std::strtol(str.substr(b * 2, 2).c_str(), nullptr, 16);
} }
} }
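A hedged sketch of the intended round trip between the two helpers above (an editor's illustration, not part of this commit; variable names are placeholders):

// NCCLUniqueId root_id;                         // constructor generates the ID
// const std::string wire = root_id.ToString();  // 2 hex characters per byte
// // ... broadcast `wire` out of band, e.g. through the launcher ...
// NCCLUniqueId peer_id;
// peer_id.FromString(wire);  // peer_id now holds the same NCCL_UNIQUE_ID_BYTES
// // Each rank then constructs NCCLCommunicator(size, rank, peer_id.Get()).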
/* NCCLCommunicator **********************************************************/ /* NCCLCommunicator **********************************************************/
NCCLCommunicator::NCCLCommunicator( NCCLCommunicator::NCCLCommunicator(
const int size, const int size, const int rank, ncclUniqueId id)
const int rank, : comm_(), size_(size), rank_(rank) {
ncclUniqueId id) : CHECK_LT(rank, size) << "The rank (" << rank
comm_(), << ") must be smaller than "
size_(size), "the size of the communicator ("
rank_(rank) { << size << ").";
CHECK_LT(rank, size) << "The rank (" << rank << ") must be smaller than " CHECK_GE(rank, 0) << "The rank (" << rank
"the size of the communicator (" << size << ")."; << ") must be greater than or "
CHECK_GE(rank, 0) << "The rank (" << rank << ") must be greater than or " "equal to 0.";
"equal to 0.";
#ifdef DGL_USE_NCCL
#ifdef DGL_USE_NCCL
NCCL_CALL(ncclCommInitRank(&comm_, size_, id, rank_)); NCCL_CALL(ncclCommInitRank(&comm_, size_, id, rank_));
#else #else
CHECK_EQ(size, 1) << "Cannot create a communicator of size " << size << ". " CHECK_EQ(size, 1)
"To use a communicator size greater than 1, compile DGL with NCCL " << "Cannot create a communicator of size " << size
"support."; << ". "
#endif "To use a communicator size greater than 1, compile DGL with NCCL "
"support.";
#endif
} }
NCCLCommunicator::~NCCLCommunicator() { NCCLCommunicator::~NCCLCommunicator() {
#ifdef DGL_USE_NCCL #ifdef DGL_USE_NCCL
ncclCommDestroy(comm_); ncclCommDestroy(comm_);
#endif #endif
} }
ncclComm_t NCCLCommunicator::Get() { ncclComm_t NCCLCommunicator::Get() { return comm_; }
return comm_;
}
template<typename DType> template <typename DType>
void NCCLCommunicator::AllToAllV( void NCCLCommunicator::AllToAllV(
const DType * const send, const DType* const send, const int64_t* const send_prefix,
const int64_t * const send_prefix, DType* const recv, const int64_t* const recv_prefix, cudaStream_t stream) {
DType * const recv, #ifdef DGL_USE_NCCL
const int64_t * const recv_prefix,
cudaStream_t stream) {
#ifdef DGL_USE_NCCL
const ncclDataType_t type = NCCLType<DType>(); const ncclDataType_t type = NCCLType<DType>();
NCCL_CALL(ncclGroupStart()); NCCL_CALL(ncclGroupStart());
for (int r = 0; r < size_; ++r) { for (int r = 0; r < size_; ++r) {
const int64_t send_size = send_prefix[r+1]-send_prefix[r]; const int64_t send_size = send_prefix[r + 1] - send_prefix[r];
if (send_size > 0) { if (send_size > 0) {
NCCL_CALL(ncclSend(send+send_prefix[r], send_size, type, r, comm_, stream)); NCCL_CALL(
ncclSend(send + send_prefix[r], send_size, type, r, comm_, stream));
} }
const int64_t recv_size = recv_prefix[r+1]-recv_prefix[r]; const int64_t recv_size = recv_prefix[r + 1] - recv_prefix[r];
if (recv_size > 0) { if (recv_size > 0) {
NCCL_CALL(ncclRecv(recv+recv_prefix[r], recv_size, type, r, comm_, stream)); NCCL_CALL(
ncclRecv(recv + recv_prefix[r], recv_size, type, r, comm_, stream));
} }
} }
NCCL_CALL(ncclGroupEnd()); NCCL_CALL(ncclGroupEnd());
#else #else
CHECK_EQ(send_prefix[1]-send_prefix[0], recv_prefix[1]-recv_prefix[0]) << CHECK_EQ(send_prefix[1] - send_prefix[0], recv_prefix[1] - recv_prefix[0])
"Send message size must equal receive message size."; << "Send message size must equal receive message size.";
int dev_id; int dev_id;
CUDA_CALL(cudaGetDevice(&dev_id)); CUDA_CALL(cudaGetDevice(&dev_id));
...@@ -623,60 +576,39 @@ void NCCLCommunicator::AllToAllV( ...@@ -623,60 +576,39 @@ void NCCLCommunicator::AllToAllV(
auto dtype = DGLDataTypeTraits<DType>::dtype; auto dtype = DGLDataTypeTraits<DType>::dtype;
// copy using the same stream (local current stream), no need to sync // copy using the same stream (local current stream), no need to sync
device->CopyDataFromTo(send, send_prefix[0], device->CopyDataFromTo(
recv, recv_prefix[0], send, send_prefix[0], recv, recv_prefix[0],
sizeof(DType)*send_prefix[1]-send_prefix[0], sizeof(DType) * send_prefix[1] - send_prefix[0], ctx, ctx, dtype);
ctx, ctx, #endif
dtype);
#endif
} }
template template void NCCLCommunicator::AllToAllV<int32_t>(
void NCCLCommunicator::AllToAllV<int32_t>( const int32_t* const send, const int64_t* send_prefix, int32_t* const recv,
const int32_t * const send, const int64_t* recv_prefix, cudaStream_t stream);
const int64_t * send_prefix, template void NCCLCommunicator::AllToAllV<int64_t>(
int32_t * const recv, const int64_t* const send, const int64_t* send_prefix, int64_t* const recv,
const int64_t * recv_prefix, const int64_t* recv_prefix, cudaStream_t stream);
cudaStream_t stream); template void NCCLCommunicator::AllToAllV<float>(
template const float* const send, const int64_t* send_prefix, float* const recv,
void NCCLCommunicator::AllToAllV<int64_t>( const int64_t* recv_prefix, cudaStream_t stream);
const int64_t * const send, template void NCCLCommunicator::AllToAllV<__half>(
const int64_t * send_prefix, const __half* const send, const int64_t* send_prefix, __half* const recv,
int64_t * const recv, const int64_t* recv_prefix, cudaStream_t stream);
const int64_t * recv_prefix,
cudaStream_t stream); template <typename IdType>
template
void NCCLCommunicator::AllToAllV<float>(
const float * const send,
const int64_t * send_prefix,
float * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
template
void NCCLCommunicator::AllToAllV<__half>(
const __half * const send,
const int64_t * send_prefix,
__half * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
template<typename IdType>
void NCCLCommunicator::AllToAll( void NCCLCommunicator::AllToAll(
const IdType * const send, const IdType* const send, IdType* const recv, const int64_t count,
IdType * const recv,
const int64_t count,
cudaStream_t stream) { cudaStream_t stream) {
#ifdef DGL_USE_NCCL #ifdef DGL_USE_NCCL
const ncclDataType_t type = NCCLType<IdType>(); const ncclDataType_t type = NCCLType<IdType>();
NCCL_CALL(ncclGroupStart()); NCCL_CALL(ncclGroupStart());
for (int r = 0; r < size_; ++r) { for (int r = 0; r < size_; ++r) {
NCCL_CALL(ncclSend(send+(r*count), count, type, r, comm_, stream)); NCCL_CALL(ncclSend(send + (r * count), count, type, r, comm_, stream));
NCCL_CALL(ncclRecv(recv+(r*count), count, type, r, comm_, stream)); NCCL_CALL(ncclRecv(recv + (r * count), count, type, r, comm_, stream));
} }
NCCL_CALL(ncclGroupEnd()); NCCL_CALL(ncclGroupEnd());
#else #else
int dev_id; int dev_id;
CUDA_CALL(cudaGetDevice(&dev_id)); CUDA_CALL(cudaGetDevice(&dev_id));
DGLContext ctx{kDGLCUDA, dev_id}; DGLContext ctx{kDGLCUDA, dev_id};
...@@ -686,156 +618,131 @@ void NCCLCommunicator::AllToAll( ...@@ -686,156 +618,131 @@ void NCCLCommunicator::AllToAll(
// copy using the same stream (local current stream), no need to sync // copy using the same stream (local current stream), no need to sync
device->CopyDataFromTo(send, 0, recv, 0, count, ctx, ctx, dtype); device->CopyDataFromTo(send, 0, recv, 0, count, ctx, ctx, dtype);
#endif #endif
} }
template template void NCCLCommunicator::AllToAll<int32_t>(
void NCCLCommunicator::AllToAll<int32_t>( const int32_t* const send, int32_t* const recv, const int64_t count,
const int32_t * const send,
int32_t * const recv,
const int64_t count,
cudaStream_t stream); cudaStream_t stream);
template template void NCCLCommunicator::AllToAll<int64_t>(
void NCCLCommunicator::AllToAll<int64_t>( const int64_t* const send, int64_t* const recv, const int64_t count,
const int64_t * const send,
int64_t * const recv,
const int64_t count,
cudaStream_t stream); cudaStream_t stream);
template <typename IdType, typename DType>
template<typename IdType, typename DType>
void NCCLCommunicator::SparseAllToAll( void NCCLCommunicator::SparseAllToAll(
const IdType * const send_idx, const IdType* const send_idx, const DType* const send_value,
const DType * const send_value, const int64_t num_feat, const int64_t* const send_prefix,
const int64_t num_feat, IdType* const recv_idx, DType* const recv_value,
const int64_t * const send_prefix, const int64_t* const recv_prefix, cudaStream_t stream) {
IdType * const recv_idx,
DType * const recv_value,
const int64_t * const recv_prefix,
cudaStream_t stream) {
// idxs // idxs
AllToAllV(send_idx, send_prefix, recv_idx, recv_prefix, stream); AllToAllV(send_idx, send_prefix, recv_idx, recv_prefix, stream);
// scale prefixes by number of features // scale prefixes by number of features
std::vector<int64_t> value_send_prefix(size_+1); std::vector<int64_t> value_send_prefix(size_ + 1);
for (int r = 0; r < size_+1; ++r) { for (int r = 0; r < size_ + 1; ++r) {
value_send_prefix[r] = send_prefix[r]*num_feat; value_send_prefix[r] = send_prefix[r] * num_feat;
} }
std::vector<int64_t> value_recv_prefix(size_+1); std::vector<int64_t> value_recv_prefix(size_ + 1);
for (int r = 0; r < size_+1; ++r) { for (int r = 0; r < size_ + 1; ++r) {
value_recv_prefix[r] = recv_prefix[r]*num_feat; value_recv_prefix[r] = recv_prefix[r] * num_feat;
} }
AllToAllV(send_value, value_send_prefix.data(), AllToAllV(
recv_value, value_recv_prefix.data(), stream); send_value, value_send_prefix.data(), recv_value,
value_recv_prefix.data(), stream);
} }
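For example (a hedged, made-up instance of the scaling above):

// With num_feat == 16 and send_prefix == {0, 3, 5}, the value exchange uses
// value_send_prefix == {0, 48, 80}: rank r receives send_prefix[r+1] -
// send_prefix[r] indices and num_feat times as many feature values.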
template void NCCLCommunicator::SparseAllToAll<int32_t, __half>(
const int32_t* const send_idx, const __half* const send_value,
const int64_t num_feat, const int64_t* const send_prefix,
int32_t* const recv_idx, __half* const recv_value,
const int64_t* const recv_prefix, cudaStream_t stream);
template void NCCLCommunicator::SparseAllToAll<int64_t, __half>(
const int64_t* const send_idx, const __half* const send_value,
const int64_t num_feat, const int64_t* const send_prefix,
int64_t* const recv_idx, __half* const recv_value,
const int64_t* const recv_prefix, cudaStream_t stream);
template int NCCLCommunicator::size() const { return size_; }
void NCCLCommunicator::SparseAllToAll<int32_t, __half>(
const int32_t * const send_idx,
const __half * const send_value,
const int64_t num_feat,
const int64_t * const send_prefix,
int32_t * const recv_idx,
__half * const recv_value,
const int64_t * const recv_prefix,
cudaStream_t stream);
template
void NCCLCommunicator::SparseAllToAll<int64_t, __half>(
const int64_t * const send_idx,
const __half * const send_value,
const int64_t num_feat,
const int64_t * const send_prefix,
int64_t * const recv_idx,
__half * const recv_value,
const int64_t * const recv_prefix,
cudaStream_t stream);
int NCCLCommunicator::size() const {
return size_;
}
int NCCLCommunicator::rank() const {
return rank_;
}
int NCCLCommunicator::rank() const { return rank_; }
/* CAPI **********************************************************************/ /* CAPI **********************************************************************/
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLGetUniqueId") DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLGetUniqueId")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = NCCLUniqueIdRef(std::make_shared<NCCLUniqueId>()); *rv = NCCLUniqueIdRef(std::make_shared<NCCLUniqueId>());
}); });
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLUniqueIdToString") DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLUniqueIdToString")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NCCLUniqueIdRef idObj = args[0]; NCCLUniqueIdRef idObj = args[0];
*rv = idObj->ToString(); *rv = idObj->ToString();
}); });
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLUniqueIdFromString") DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLUniqueIdFromString")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const std::string str = args[0]; const std::string str = args[0];
NCCLUniqueIdRef ref(std::make_shared<NCCLUniqueId>()); NCCLUniqueIdRef ref(std::make_shared<NCCLUniqueId>());
ref->FromString(str); ref->FromString(str);
*rv = ref; *rv = ref;
}); });
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLCreateComm") DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLCreateComm")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const int size = args[0]; const int size = args[0];
const int rank = args[1]; const int rank = args[1];
NCCLUniqueIdRef idObj = args[2]; NCCLUniqueIdRef idObj = args[2];
*rv = NCCLCommunicatorRef(std::make_shared<NCCLCommunicator>(size, rank, *rv = NCCLCommunicatorRef(
idObj->Get())); std::make_shared<NCCLCommunicator>(size, rank, idObj->Get()));
}); });
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLSparseAllToAllPush") DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLSparseAllToAllPush")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NCCLCommunicatorRef comm = args[0]; NCCLCommunicatorRef comm = args[0];
IdArray in_idx = args[1]; IdArray in_idx = args[1];
NDArray in_values = args[2]; NDArray in_values = args[2];
NDArrayPartitionRef part = args[3]; NDArrayPartitionRef part = args[3];
List<ObjectRef> ret; List<ObjectRef> ret;
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_DTYPE_SWITCH(in_values->dtype, DType, "values", { ATEN_DTYPE_SWITCH(in_values->dtype, DType, "values", {
auto result = SparsePush<IdType, DType>(comm, in_idx, in_values, part); auto result =
    SparsePush<IdType, DType>(comm, in_idx, in_values, part);
ret.push_back(Value(MakeValue(result.first))); ret.push_back(Value(MakeValue(result.first)));
ret.push_back(Value(MakeValue(result.second))); ret.push_back(Value(MakeValue(result.second)));
}); });
}); });
*rv = ret; *rv = ret;
}); });
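As a rough guide to what the nested type switches select, here is the shape of one dispatched branch; the type choices and the comments on the result are a sketch of the intended semantics, not actual macro output:
// One (IdType, DType) combination the switches above can resolve to.
using IdType = int32_t;  // when in_idx stores 32-bit ids
using DType = __half;    // when in_values stores float16 rows
auto result = SparsePush<IdType, DType>(comm, in_idx, in_values, part);
// result.first  : IdArray of indices delivered to this rank (the ones it
//                 owns under `part`)
// result.second : NDArray holding the matching feature rows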
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLSparseAllToAllPull") DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLSparseAllToAllPull")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NCCLCommunicatorRef comm = args[0]; NCCLCommunicatorRef comm = args[0];
// the indexes this process is requesting from others // the indexes this process is requesting from others
IdArray req_idx = args[1]; IdArray req_idx = args[1];
// the tensor this process has to fulfill other requests // the tensor this process has to fulfill other requests
NDArray tensor = args[2]; NDArray tensor = args[2];
NDArrayPartitionRef part = args[3]; NDArrayPartitionRef part = args[3];
ATEN_ID_TYPE_SWITCH(req_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(req_idx->dtype, IdType, {
ATEN_DTYPE_SWITCH(tensor->dtype, DType, "values", { ATEN_DTYPE_SWITCH(tensor->dtype, DType, "values", {
*rv = SparsePull<IdType, DType>(comm, req_idx, tensor, part); *rv = SparsePull<IdType, DType>(comm, req_idx, tensor, part);
}); });
}); });
}); });
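The pull direction is the mirror image: req_idx names the rows this rank wants, tensor is the local slice used to answer other ranks, and, judging from the assignment to *rv, SparsePull returns the gathered rows. A hedged sketch of one dispatched branch:
using IdType = int64_t;  // when req_idx stores 64-bit ids
using DType = float;     // when tensor stores float32 rows
NDArray rows = SparsePull<IdType, DType>(comm, req_idx, tensor, part);
// rows has one row per entry of req_idx, fetched from whichever rank owns
// that id under `part`, in the order the ids were requested.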
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLHasSupport") DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLHasSupport")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
#ifndef DGL_USE_NCCL #ifndef DGL_USE_NCCL
return false; return false;
#else #else
return true; return true;
#endif #endif
}); });
} // namespace cuda } // namespace cuda
} // namespace runtime } // namespace runtime
......
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
* limitations under the License. * limitations under the License.
* *
* \file nccl_api.h * \file nccl_api.h
* \brief Wrapper around NCCL routines. * \brief Wrapper around NCCL routines.
*/ */
#ifndef DGL_RUNTIME_CUDA_NCCL_API_H_ #ifndef DGL_RUNTIME_CUDA_NCCL_API_H_
#define DGL_RUNTIME_CUDA_NCCL_API_H_ #define DGL_RUNTIME_CUDA_NCCL_API_H_
...@@ -27,11 +26,14 @@ ...@@ -27,11 +26,14 @@
// if not compiling with NCCL, this class will only support communicators of // if not compiling with NCCL, this class will only support communicators of
// size 1. // size 1.
#define NCCL_UNIQUE_ID_BYTES 128 #define NCCL_UNIQUE_ID_BYTES 128
typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId; typedef struct {
char internal[NCCL_UNIQUE_ID_BYTES];
} ncclUniqueId;
typedef int ncclComm_t; typedef int ncclComm_t;
#endif #endif
#include <dgl/runtime/object.h> #include <dgl/runtime/object.h>
#include <string> #include <string>
namespace dgl { namespace dgl {
...@@ -59,17 +61,13 @@ DGL_DEFINE_OBJECT_REF(NCCLUniqueIdRef, NCCLUniqueId); ...@@ -59,17 +61,13 @@ DGL_DEFINE_OBJECT_REF(NCCLUniqueIdRef, NCCLUniqueId);
class NCCLCommunicator : public runtime::Object { class NCCLCommunicator : public runtime::Object {
public: public:
NCCLCommunicator( NCCLCommunicator(int size, int rank, ncclUniqueId id);
int size,
int rank,
ncclUniqueId id);
~NCCLCommunicator(); ~NCCLCommunicator();
// disable copying // disable copying
NCCLCommunicator(const NCCLCommunicator& other) = delete; NCCLCommunicator(const NCCLCommunicator& other) = delete;
NCCLCommunicator& operator=( NCCLCommunicator& operator=(const NCCLCommunicator& other);
const NCCLCommunicator& other);
ncclComm_t Get(); ncclComm_t Get();
...@@ -81,12 +79,9 @@ class NCCLCommunicator : public runtime::Object { ...@@ -81,12 +79,9 @@ class NCCLCommunicator : public runtime::Object {
* @param count The size of data to send to each rank. * @param count The size of data to send to each rank.
* @param stream The stream to operate on. * @param stream The stream to operate on.
*/ */
template<typename IdType> template <typename IdType>
void AllToAll( void AllToAll(
const IdType * send, const IdType* send, IdType* recv, int64_t count, cudaStream_t stream);
IdType * recv,
int64_t count,
cudaStream_t stream);
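A hedged usage sketch for this fixed-size variant; the device pointers, rank count, and stream are assumptions:
// With comm.size() == 4 and count == 2, d_send holds 8 IdType entries:
// entries [2*r, 2*r + 2) go to rank r, and d_recv is filled with the two
// entries contributed by each rank, in rank order.
comm.AllToAll<int64_t>(d_send, d_recv, /*count=*/2, stream);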
/** /**
* @brief Perform an all-to-all variable sized communication. * @brief Perform an all-to-all variable sized communication.
...@@ -99,13 +94,10 @@ class NCCLCommunicator : public runtime::Object { ...@@ -99,13 +94,10 @@ class NCCLCommunicator : public runtime::Object {
* @param type The type of data to send. * @param type The type of data to send.
* @param stream The stream to operate on. * @param stream The stream to operate on.
*/ */
template<typename DType> template <typename DType>
void AllToAllV( void AllToAllV(
const DType * const send, const DType* const send, const int64_t* send_prefix, DType* const recv,
const int64_t * send_prefix, const int64_t* recv_prefix, cudaStream_t stream);
DType * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
/** /**
* @brief Perform an all-to-all with sparse data (idx and value pairs). By * @brief Perform an all-to-all with sparse data (idx and value pairs). By
...@@ -124,16 +116,11 @@ class NCCLCommunicator : public runtime::Object { ...@@ -124,16 +116,11 @@ class NCCLCommunicator : public runtime::Object {
* receive on the host. * receive on the host.
* @param stream The stream to communicate on. * @param stream The stream to communicate on.
*/ */
template<typename IdType, typename DType> template <typename IdType, typename DType>
void SparseAllToAll( void SparseAllToAll(
const IdType * send_idx, const IdType* send_idx, const DType* send_value, const int64_t num_feat,
const DType * send_value, const int64_t* send_prefix, IdType* recv_idx, DType* recv_value,
const int64_t num_feat, const int64_t* recv_prefix, cudaStream_t stream);
const int64_t * send_prefix,
IdType * recv_idx,
DType * recv_value,
const int64_t * recv_prefix,
cudaStream_t stream);
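A hedged sketch of how the prefix arrays are usually laid out, here for two ranks with illustrative sizes; every name below is an assumption:
// Host-side prefixes with comm.size() + 1 entries; slots i and i+1 bound the
// rows exchanged with rank i.
const int64_t send_prefix[] = {0, 3, 5};  // 3 rows to rank 0, 2 rows to rank 1
const int64_t recv_prefix[] = {0, 2, 6};  // 2 rows from rank 0, 4 from rank 1
// d_send_idx holds send_prefix[2] ids and d_send_val holds
// send_prefix[2] * num_feat values; the recv buffers are sized from
// recv_prefix[2] in the same way.
comm.SparseAllToAll<int64_t, float>(
    d_send_idx, d_send_val, num_feat, send_prefix,
    d_recv_idx, d_recv_val, recv_prefix, stream);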
int size() const; int size() const;
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
* \file src/runtime/dlpack_convert.cc * \file src/runtime/dlpack_convert.cc
* \brief Conversion between NDArray and DLPack. * \brief Conversion between NDArray and DLPack.
*/ */
#include <dgl/runtime/dlpack_convert.h>
#include <dlpack/dlpack.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <dgl/runtime/dlpack_convert.h>
#include <dgl/runtime/ndarray.h>
#include <dlpack/dlpack.h>
#include "runtime_base.h" #include "runtime_base.h"
// deleter for arrays used by DLPack exporter // deleter for arrays used by DLPack exporter
...@@ -69,8 +69,7 @@ NDArray DLPackConvert::FromDLPack(DLManagedTensor* tensor) { ...@@ -69,8 +69,7 @@ NDArray DLPackConvert::FromDLPack(DLManagedTensor* tensor) {
void DLPackConvert::DLPackDeleter(NDArray::Container* ptr) { void DLPackConvert::DLPackDeleter(NDArray::Container* ptr) {
// if the array is pinned by dgl, unpin it before freeing // if the array is pinned by dgl, unpin it before freeing
if (ptr->pinned_by_dgl_) if (ptr->pinned_by_dgl_) NDArray::UnpinContainer(ptr);
NDArray::UnpinContainer(ptr);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx); DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);
if (tensor->deleter != nullptr) { if (tensor->deleter != nullptr) {
(*tensor->deleter)(tensor); (*tensor->deleter)(tensor);
...@@ -95,7 +94,7 @@ DLManagedTensor* ContainerToDLPack(NDArray::Container* from) { ...@@ -95,7 +94,7 @@ DLManagedTensor* ContainerToDLPack(NDArray::Container* from) {
return ret; return ret;
} }
DLManagedTensor* DLPackConvert::ToDLPack(const NDArray &from) { DLManagedTensor* DLPackConvert::ToDLPack(const NDArray& from) {
return ContainerToDLPack(from.data_); return ContainerToDLPack(from.data_);
} }
...@@ -113,15 +112,14 @@ inline bool IsAligned(const void* ptr, std::uintptr_t alignment) noexcept { ...@@ -113,15 +112,14 @@ inline bool IsAligned(const void* ptr, std::uintptr_t alignment) noexcept {
return !(iptr % alignment); return !(iptr % alignment);
} }
int DGLArrayFromDLPack(DLManagedTensor* from, int DGLArrayFromDLPack(DLManagedTensor* from, DGLArrayHandle* out) {
DGLArrayHandle* out) {
API_BEGIN(); API_BEGIN();
*out = NDArray::Internal::MoveAsDGLArray(DLPackConvert::FromDLPack(from)); *out = NDArray::Internal::MoveAsDGLArray(DLPackConvert::FromDLPack(from));
API_END(); API_END();
} }
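For orientation, a minimal sketch of the conversion round trip these C entry points wrap, using the C++ helpers directly; dlm is assumed to be a DLManagedTensor handed over by another DLPack-capable framework, and the snippet is written as if inside namespace dgl::runtime like the code above:
// Import: DGL takes ownership of dlm and runs its deleter once the NDArray's
// last reference goes away.
NDArray arr = DLPackConvert::FromDLPack(dlm);
// Export: produces a new DLManagedTensor that keeps arr's buffer alive until
// the consumer calls exported->deleter(exported).
DLManagedTensor* exported = DLPackConvert::ToDLPack(arr);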
int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out, int DGLArrayToDLPack(
int alignment) { DGLArrayHandle from, DLManagedTensor** out, int alignment) {
API_BEGIN(); API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(from); auto* nd_container = reinterpret_cast<NDArray::Container*>(from);
DGLArray* nd = &(nd_container->dl_tensor); DGLArray* nd = &(nd_container->dl_tensor);
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
* \brief Module to load from dynamic shared library. * \brief Module to load from dynamic shared library.
*/ */
#include <dgl/runtime/module.h> #include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include "module_util.h" #include "module_util.h"
#if defined(_WIN32) #if defined(_WIN32)
...@@ -25,9 +26,7 @@ class DSOModuleNode final : public ModuleNode { ...@@ -25,9 +26,7 @@ class DSOModuleNode final : public ModuleNode {
if (lib_handle_) Unload(); if (lib_handle_) Unload();
} }
const char* type_key() const final { const char* type_key() const final { return "dso"; }
return "dso";
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
...@@ -36,8 +35,9 @@ class DSOModuleNode final : public ModuleNode { ...@@ -36,8 +35,9 @@ class DSOModuleNode final : public ModuleNode {
if (name == runtime::symbol::dgl_module_main) { if (name == runtime::symbol::dgl_module_main) {
const char* entry_name = reinterpret_cast<const char*>( const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::dgl_module_main)); GetSymbol(runtime::symbol::dgl_module_main));
CHECK(entry_name!= nullptr) CHECK(entry_name != nullptr)
<< "Symbol " << runtime::symbol::dgl_module_main << " is not presented"; << "Symbol " << runtime::symbol::dgl_module_main
<< " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name)); faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else { } else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str())); faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
...@@ -48,17 +48,15 @@ class DSOModuleNode final : public ModuleNode { ...@@ -48,17 +48,15 @@ class DSOModuleNode final : public ModuleNode {
void Init(const std::string& name) { void Init(const std::string& name) {
Load(name); Load(name);
if (auto *ctx_addr = if (auto* ctx_addr = reinterpret_cast<void**>(
reinterpret_cast<void**>(GetSymbol(runtime::symbol::dgl_module_ctx))) { GetSymbol(runtime::symbol::dgl_module_ctx))) {
*ctx_addr = this; *ctx_addr = this;
} }
InitContextFunctions([this](const char* fname) { InitContextFunctions(
return GetSymbol(fname); [this](const char* fname) { return GetSymbol(fname); });
});
// Load the imported modules // Load the imported modules
const char* dev_mblob = const char* dev_mblob = reinterpret_cast<const char*>(
reinterpret_cast<const char*>( GetSymbol(runtime::symbol::dgl_dev_mblob));
GetSymbol(runtime::symbol::dgl_dev_mblob));
if (dev_mblob != nullptr) { if (dev_mblob != nullptr) {
ImportModuleBlob(dev_mblob, &imports_); ImportModuleBlob(dev_mblob, &imports_);
} }
...@@ -79,11 +77,9 @@ class DSOModuleNode final : public ModuleNode { ...@@ -79,11 +77,9 @@ class DSOModuleNode final : public ModuleNode {
} }
void* GetSymbol(const char* name) { void* GetSymbol(const char* name) {
return reinterpret_cast<void*>( return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
} }
void Unload() { FreeLibrary(lib_handle_); }
#else #else
// Library handle // Library handle
void* lib_handle_{nullptr}; void* lib_handle_{nullptr};
...@@ -91,23 +87,18 @@ class DSOModuleNode final : public ModuleNode { ...@@ -91,23 +87,18 @@ class DSOModuleNode final : public ModuleNode {
void Load(const std::string& name) { void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(lib_handle_ != nullptr) CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name << "Failed to load dynamic shared library " << name << " " << dlerror();
<< " " << dlerror();
}
void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name);
}
void Unload() {
dlclose(lib_handle_);
} }
void* GetSymbol(const char* name) { return dlsym(lib_handle_, name); }
void Unload() { dlclose(lib_handle_); }
#endif #endif
}; };
DGL_REGISTER_GLOBAL("module.loadfile_so") DGL_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>(); std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]); n->Init(args[0]);
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
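A hedged sketch of exercising this loader from C++, assuming DGL's TVM-style registry and packed-function conventions; the library path and symbol name are placeholders:
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
using namespace dgl::runtime;
const PackedFunc* load = Registry::Get("module.loadfile_so");
CHECK(load != nullptr) << "loader not registered";
Module mod = (*load)(std::string("libmy_ops.so"));   // hypothetical library
PackedFunc fn = mod.GetFunction("my_packed_entry");  // hypothetical symbol
if (fn != nullptr) {
  fn();  // invoke the exported packed function
}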
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl